21 #ifndef otbKMeansImageClassificationFilter_hxx
22 #define otbKMeansImageClassificationFilter_hxx
26 #include "itkImageRegionIterator.h"
27 #include "itkNumericTraits.h"
34 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
37 this->SetNumberOfRequiredInputs(2);
38 this->SetNumberOfRequiredInputs(1);
39 m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
43 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
46 this->itk::ProcessObject::SetNthInput(1,
const_cast<MaskImageType*
>(mask));
49 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
53 if (this->GetNumberOfInputs() < 2)
57 return static_cast<const MaskImageType*
>(this->itk::ProcessObject::GetInput(1));
60 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
63 unsigned int sample_size = MaxSampleDimension;
64 unsigned int nb_classes = m_Centroids.Size() / sample_size;
66 for (LabelType label = 1; label <= static_cast<LabelType>(nb_classes); ++label)
70 m_CentroidsMap[label] = new_centroid;
72 for (
unsigned int i = 0; i < MaxSampleDimension; ++i)
74 m_CentroidsMap[label][i] =
static_cast<ValueType
>(m_Centroids[MaxSampleDimension * (
static_cast<unsigned int>(label) - 1) + i]);
79 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
81 const OutputImageRegionType& outputRegionForThread)
83 InputImageConstPointerType inputPtr = this->GetInput();
84 MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
85 OutputImagePointerType outputPtr = this->GetOutput();
87 typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
88 typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
89 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
91 InputIteratorType inIt(inputPtr, outputRegionForThread);
92 OutputIteratorType outIt(outputPtr, outputRegionForThread);
94 MaskIteratorType maskIt;
97 maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
100 unsigned int maxDimension = SampleType::Dimension;
101 unsigned int sampleSize = std::min(inputPtr->GetNumberOfComponentsPerPixel(), maxDimension);
103 bool validPoint =
true;
105 while (!outIt.IsAtEnd())
107 outIt.Set(m_DefaultLabel);
115 typename DistanceType::Pointer distance = DistanceType::New();
117 while (!outIt.IsAtEnd() && (!inIt.IsAtEnd()))
121 validPoint = maskIt.Get() > 0;
127 LabelType current_label = 1;
130 for (
unsigned int i = 0; i < sampleSize; ++i)
132 pixel[i] = inIt.Get()[i];
135 double current_distance = distance->Evaluate(pixel, m_CentroidsMap[label]);
137 for (label = 2; label <= static_cast<LabelType>(m_CentroidsMap.size()); ++label)
139 double tmp_dist = distance->Evaluate(pixel, m_CentroidsMap[label]);
140 if (tmp_dist < current_distance)
142 current_label = label;
143 current_distance = tmp_dist;
146 outIt.Set(current_label);
void SetInputMask(const MaskImageType *mask)
void DynamicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread) override
KMeansImageClassificationFilter()
const MaskImageType * GetInputMask(void)
void BeforeThreadedGenerateData() override
std::vector< double > SampleType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.