OTB  10.0.0
Orfeo Toolbox
otbSOMImageClassificationFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSOMImageClassificationFilter_hxx
22 #define otbSOMImageClassificationFilter_hxx
23 
25 #include "itkImageRegionIterator.h"
26 #include "itkNumericTraits.h"
27 
28 namespace otb
29 {
33 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
35 {
36  this->SetNumberOfRequiredInputs(2);
37  this->SetNumberOfRequiredInputs(1);
38  m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
39 }
41 
42 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
44 {
45  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType*>(mask));
46 }
47 
48 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
51 {
52  if (this->GetNumberOfInputs() < 2)
53  {
54  return nullptr;
55  }
56  return static_cast<const MaskImageType*>(this->itk::ProcessObject::GetInput(1));
57 }
58 
59 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
61 {
62  if (!m_Map)
63  {
64  itkGenericExceptionMacro(<< "No model for classification");
65  }
66 }
67 
68 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
70 {
71  InputImageConstPointerType inputPtr = this->GetInput();
72  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
73  OutputImagePointerType outputPtr = this->GetOutput();
74 
75  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
76  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
77  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
78 
79  ListSamplePointerType listSample = ListSampleType::New();
80  listSample->SetMeasurementVectorSize(inputPtr->GetNumberOfComponentsPerPixel());
81 
82  InputIteratorType inIt(inputPtr, outputRegionForThread);
83 
84  MaskIteratorType maskIt;
85  if (inputMaskPtr)
86  {
87  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
88  maskIt.GoToBegin();
89  }
90  unsigned int maxDimension = m_Map->GetNumberOfComponentsPerPixel();
91  unsigned int sampleSize = std::min(inputPtr->GetNumberOfComponentsPerPixel(), maxDimension);
92  bool validPoint = true;
93 
94  for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
95  {
96  if (inputMaskPtr)
97  {
98  validPoint = maskIt.Get() > 0;
99  ++maskIt;
100  }
101  if (validPoint)
102  {
103  SampleType sample;
104  sample.SetSize(sampleSize);
105  sample.Fill(itk::NumericTraits<ValueType>::ZeroValue());
106  for (unsigned int i = 0; i < sampleSize; ++i)
107  {
108  sample[i] = inIt.Get()[i];
109  }
110  listSample->PushBack(sample);
111  }
112  }
113  ClassifierPointerType classifier = ClassifierType::New();
114  classifier->SetMap(m_Map);
115  classifier->SetSample(listSample);
116  classifier->Update();
117 
118  typename ClassifierType::OutputType::Pointer membershipSample = classifier->GetOutput();
119  typename ClassifierType::OutputType::ConstIterator sampleIter = membershipSample->Begin();
120  typename ClassifierType::OutputType::ConstIterator sampleLast = membershipSample->End();
121 
122  OutputIteratorType outIt(outputPtr, outputRegionForThread);
123 
124  outIt.GoToBegin();
125 
126  while (!outIt.IsAtEnd() && (sampleIter != sampleLast))
127  {
128  outIt.Set(m_DefaultLabel);
129  ++outIt;
130  }
131 
132  outIt.GoToBegin();
133 
134  if (inputMaskPtr)
135  {
136  maskIt.GoToBegin();
137  }
138  validPoint = true;
139 
140  while (!outIt.IsAtEnd() && (sampleIter != sampleLast))
141  {
142  if (inputMaskPtr)
143  {
144  validPoint = maskIt.Get() > 0;
145  ++maskIt;
146  }
147  if (validPoint)
148  {
149  outIt.Set(sampleIter.GetClassLabel());
150  ++sampleIter;
151  }
152  ++outIt;
153  }
154 }
void DynamicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread) override
const MaskImageType * GetInputMask(void)
void SetInputMask(const MaskImageType *mask)
void BeforeThreadedGenerateData() override
std::vector< double > SampleType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.