21 #ifndef otbKMeansImageClassificationFilter_hxx
22 #define otbKMeansImageClassificationFilter_hxx
25 #include "itkImageRegionIterator.h"
26 #include "itkNumericTraits.h"
// Constructor body (the signature and braces are not part of this extract).
// Configures how many inputs the filter requires and initializes m_DefaultLabel.
33 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
// SetNumberOfRequiredInputs is called with 2 and then immediately with 1; the
// last call wins, so only input 0 is required and the mask stored at input
// index 1 is not required.
// NOTE(review): the first call looks redundant unless ITK's setter has a side
// effect (e.g. on the number of indexed input slots) that later code such as the
// GetNumberOfInputs() check in the mask getter relies on — confirm before removing.
36 this->SetNumberOfRequiredInputs(2);
37 this->SetNumberOfRequiredInputs(1);
// Default label is the zero value of LabelType. The threaded generation code
// writes this value into the output region before classifying pixels;
// presumably it marks pixels rejected by the mask — confirm.
38 m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
// Stores the mask image as input index 1 of the underlying itk::ProcessObject.
// NOTE(review): `mask` is presumably a const MaskImageType pointer (the signature
// is not in this extract); the const_cast appears to exist because SetNthInput
// takes a non-const pointer — confirm.
42 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
45 this->itk::ProcessObject::SetNthInput(1,
const_cast<MaskImageType*
>(mask));
// Returns the mask image stored at input index 1 as a const MaskImageType pointer.
48 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
// Guard: fewer than two inputs means no mask slot exists. The statements of this
// branch are not in this extract; presumably it returns a null pointer — confirm.
52 if (this->GetNumberOfInputs() < 2)
// Input 1 is downcast with static_cast, i.e. without a runtime type check.
56 return static_cast<const MaskImageType*
>(this->itk::ProcessObject::GetInput(1));
// Converts the flat centroid array m_Centroids into m_CentroidsMap, one entry per
// class label. m_Centroids is read as nb_classes consecutive centroids of
// MaxSampleDimension components each. Labels are 1-based, so label L reads the
// elements [MaxSampleDimension * (L - 1), MaxSampleDimension * L).
59 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
62 unsigned int sample_size = MaxSampleDimension;
// Integer division: trailing elements that do not form a whole centroid of
// sample_size components are ignored.
63 unsigned int nb_classes = m_Centroids.Size() / sample_size;
65 for (LabelType label = 1; label <= static_cast<LabelType>(nb_classes); ++label)
// new_centroid is declared in lines not included in this extract.
69 m_CentroidsMap[label] = new_centroid;
// Copy each component of this label's centroid, converting it to ValueType.
71 for (
unsigned int i = 0; i < MaxSampleDimension; ++i)
73 m_CentroidsMap[label][i] =
static_cast<ValueType
>(m_Centroids[MaxSampleDimension * (
static_cast<unsigned int>(label) - 1) + i]);
78 template <
class TInputImage,
class TOutputImage,
unsigned int VMaxSampleDimension,
class TMaskImage>
80 const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType itkNotUsed(threadId))
82 InputImageConstPointerType inputPtr = this->GetInput();
83 MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
84 OutputImagePointerType outputPtr = this->GetOutput();
86 typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
87 typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
88 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
90 InputIteratorType inIt(inputPtr, outputRegionForThread);
91 OutputIteratorType outIt(outputPtr, outputRegionForThread);
93 MaskIteratorType maskIt;
96 maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
99 unsigned int maxDimension = SampleType::Dimension;
100 unsigned int sampleSize = std::min(inputPtr->GetNumberOfComponentsPerPixel(), maxDimension);
102 bool validPoint =
true;
104 while (!outIt.IsAtEnd())
106 outIt.Set(m_DefaultLabel);
114 typename DistanceType::Pointer distance = DistanceType::New();
116 while (!outIt.IsAtEnd() && (!inIt.IsAtEnd()))
120 validPoint = maskIt.Get() > 0;
126 LabelType current_label = 1;
129 for (
unsigned int i = 0; i < sampleSize; ++i)
131 pixel[i] = inIt.Get()[i];
134 double current_distance = distance->Evaluate(pixel, m_CentroidsMap[label]);
136 for (label = 2; label <= static_cast<LabelType>(m_CentroidsMap.size()); ++label)
138 double tmp_dist = distance->Evaluate(pixel, m_CentroidsMap[label]);
139 if (tmp_dist < current_distance)
141 current_label = label;
142 current_distance = tmp_dist;
145 outIt.Set(current_label);