OTB  10.0.0
Orfeo Toolbox
otbKMeansImageClassificationFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbKMeansImageClassificationFilter_hxx
22 #define otbKMeansImageClassificationFilter_hxx
23 
24 #include "otbMacro.h" //for
26 #include "itkImageRegionIterator.h"
27 #include "itkNumericTraits.h"
28 
29 namespace otb
30 {
34 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
36 {
37  this->SetNumberOfRequiredInputs(2);
38  this->SetNumberOfRequiredInputs(1);
39  m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
40 }
42 
43 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
45 {
46  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType*>(mask));
47 }
48 
49 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
52 {
53  if (this->GetNumberOfInputs() < 2)
54  {
55  return nullptr;
56  }
57  return static_cast<const MaskImageType*>(this->itk::ProcessObject::GetInput(1));
58 }
59 
60 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
62 {
63  unsigned int sample_size = MaxSampleDimension;
64  unsigned int nb_classes = m_Centroids.Size() / sample_size;
65 
66  for (LabelType label = 1; label <= static_cast<LabelType>(nb_classes); ++label)
67  {
68  SampleType new_centroid;
69  new_centroid.Fill(0);
70  m_CentroidsMap[label] = new_centroid;
71 
72  for (unsigned int i = 0; i < MaxSampleDimension; ++i)
73  {
74  m_CentroidsMap[label][i] = static_cast<ValueType>(m_Centroids[MaxSampleDimension * (static_cast<unsigned int>(label) - 1) + i]);
75  }
76  }
77 }
78 
79 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
81  const OutputImageRegionType& outputRegionForThread)
82 {
83  InputImageConstPointerType inputPtr = this->GetInput();
84  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
85  OutputImagePointerType outputPtr = this->GetOutput();
86 
87  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
88  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
89  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
90 
91  InputIteratorType inIt(inputPtr, outputRegionForThread);
92  OutputIteratorType outIt(outputPtr, outputRegionForThread);
93 
94  MaskIteratorType maskIt;
95  if (inputMaskPtr)
96  {
97  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
98  maskIt.GoToBegin();
99  }
100  unsigned int maxDimension = SampleType::Dimension;
101  unsigned int sampleSize = std::min(inputPtr->GetNumberOfComponentsPerPixel(), maxDimension);
102 
103  bool validPoint = true;
104 
105  while (!outIt.IsAtEnd())
106  {
107  outIt.Set(m_DefaultLabel);
108  ++outIt;
109  }
110 
111  outIt.GoToBegin();
112 
113  validPoint = true;
114 
115  typename DistanceType::Pointer distance = DistanceType::New();
116 
117  while (!outIt.IsAtEnd() && (!inIt.IsAtEnd()))
118  {
119  if (inputMaskPtr)
120  {
121  validPoint = maskIt.Get() > 0;
122  ++maskIt;
123  }
124  if (validPoint)
125  {
126  LabelType label = 1;
127  LabelType current_label = 1;
128  SampleType pixel;
129  pixel.Fill(0);
130  for (unsigned int i = 0; i < sampleSize; ++i)
131  {
132  pixel[i] = inIt.Get()[i];
133  }
134 
135  double current_distance = distance->Evaluate(pixel, m_CentroidsMap[label]);
136 
137  for (label = 2; label <= static_cast<LabelType>(m_CentroidsMap.size()); ++label)
138  {
139  double tmp_dist = distance->Evaluate(pixel, m_CentroidsMap[label]);
140  if (tmp_dist < current_distance)
141  {
142  current_label = label;
143  current_distance = tmp_dist;
144  }
145  }
146  outIt.Set(current_label);
147  }
148  ++outIt;
149  ++inIt;
150  }
151 }
void SetInputMask(const MaskImageType *mask)
void DynamicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread) override
const MaskImageType * GetInputMask(void)
std::vector< double > SampleType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.