OTB  10.0.0
Orfeo Toolbox
otbImageClassificationFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbImageClassificationFilter_hxx
22 #define otbImageClassificationFilter_hxx
23 
25 #include "otbMacro.h" //for
26 #include "itkImageRegionIterator.h"
27 #include "itkProgressReporter.h"
28 
29 namespace otb
30 {
// Constructor: declares the filter's inputs/outputs and its default settings.
// NOTE(review): source lines 35 (the constructor signature) and 41 are absent from this listing.
34 template <class TInputImage, class TOutputImage, class TMaskImage>
36 {
// Threaded work is done in ThreadedGenerateData(region, threadId), so dynamic multi-threading is disabled.
37  this->DynamicMultiThreadingOff();
// Input 0: the image to classify (required). Input 1: the optional mask set through SetInputMask.
38  this->SetNumberOfIndexedInputs(2);
39  this->SetNumberOfRequiredInputs(1);
// Label written on pixels rejected by the mask (or left without a prediction in batch mode).
40  m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
42
// Output 0: label image, output 1: confidence map, output 2: probability map.
43  this->SetNumberOfRequiredOutputs(3);
44  this->SetNthOutput(0, TOutputImage::New());
45  this->SetNthOutput(1, ConfidenceImageType::New());
46  this->SetNthOutput(2, ProbaImageType::New());
// Defaults: no confidence/probability maps, batch prediction enabled,
// and one component per probability vector until the caller sets m_NumberOfClasses.
47  m_UseConfidenceMap = false;
48  m_UseProbaMap = false;
49  m_BatchMode = true;
50  m_NumberOfClasses = 1;
51 }
52 
53 template <class TInputImage, class TOutputImage, class TMaskImage>
55 {
56  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType*>(mask));
57 }
58 
59 template <class TInputImage, class TOutputImage, class TMaskImage>
62 {
63  if (this->GetNumberOfInputs() < 2)
64  {
65  return nullptr;
66  }
67  return static_cast<const MaskImageType*>(this->itk::ProcessObject::GetInput(1));
68 }
69 
70 template <class TInputImage, class TOutputImage, class TMaskImage>
73 {
74  if (this->GetNumberOfOutputs() < 2)
75  {
76  return nullptr;
77  }
78  return static_cast<ConfidenceImageType*>(this->itk::ProcessObject::GetOutput(1));
79 }
80 
81 template <class TInputImage, class TOutputImage, class TMaskImage>
84 {
85  if (this->GetNumberOfOutputs() < 2)
86  {
87  return nullptr;
88  }
89  return static_cast<ProbaImageType*>(this->itk::ProcessObject::GetOutput(2));
90 }
91 
92 template <class TInputImage, class TOutputImage, class TMaskImage>
94 {
95  if (!m_Model)
96  {
97  itkGenericExceptionMacro(<< "No model for classification");
98  }
99  if (m_BatchMode)
100  {
101 #ifdef _OPENMP
102  // OpenMP will take care of threading
103  this->SetNumberOfWorkUnits(1);
104 #endif
105  }
106 }
107 
108 template <class TInputImage, class TOutputImage, class TMaskImage>
109 void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>::ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,
110  itk::ThreadIdType threadId)
111 {
112  // Get the input pointers
113  InputImageConstPointerType inputPtr = this->GetInput();
114  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
115  OutputImagePointerType outputPtr = this->GetOutput();
116  ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
117  ProbaImagePointerType probaPtr = this->GetOutputProba();
118  // Progress reporting
119  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
120 
121  // Define iterators
122  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
123  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
124  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
125  typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
126  typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
127 
128  InputIteratorType inIt(inputPtr, outputRegionForThread);
129  OutputIteratorType outIt(outputPtr, outputRegionForThread);
130 
131  // Eventually iterate on masks
132  MaskIteratorType maskIt;
133  if (inputMaskPtr)
134  {
135  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
136  maskIt.GoToBegin();
137  }
138 
139  // setup iterator for confidence map
140  bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
141  ConfidenceMapIteratorType confidenceIt;
142  if (computeConfidenceMap)
143  {
144  confidenceIt = ConfidenceMapIteratorType(confidencePtr, outputRegionForThread);
145  confidenceIt.GoToBegin();
146  }
147 
148  // setup iterator for proba map
149  bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
150 
151  ProbaMapIteratorType probaIt;
152 
153  if (computeProbaMap)
154  {
155  probaIt = ProbaMapIteratorType(probaPtr, outputRegionForThread);
156  probaIt.GoToBegin();
157  }
158 
159  bool validPoint = true;
160  double confidenceIndex = 0.0;
161  ProbaSampleType probaVector{m_NumberOfClasses};
162  probaVector.Fill(0);
163  // Walk the part of the image
164  for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
165  {
166  // Check pixel validity
167  if (inputMaskPtr)
168  {
169  validPoint = maskIt.Get() > 0;
170  ++maskIt;
171  }
172  // If point is valid
173  if (validPoint)
174  {
175  // Classifify
176  if (computeProbaMap)
177  {
178  outIt.Set(m_Model->Predict(inIt.Get(), &confidenceIndex, &probaVector)[0]);
179  }
180  else if (computeConfidenceMap)
181  {
182  outIt.Set(m_Model->Predict(inIt.Get(), &confidenceIndex)[0]);
183  }
184  else
185  {
186  outIt.Set(m_Model->Predict(inIt.Get())[0]);
187  }
188  }
189  else
190  {
191  // else, set default value
192  outIt.Set(m_DefaultLabel);
193  confidenceIndex = 0.0;
194  }
195  if (computeConfidenceMap)
196  {
197  confidenceIt.Set(confidenceIndex);
198  ++confidenceIt;
199  }
200  if (computeProbaMap)
201  {
202  probaIt.Set(probaVector);
203  ++probaIt;
204  }
205  progress.CompletedPixel();
206  }
207 }
208 
209 template <class TInputImage, class TOutputImage, class TMaskImage>
210 void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,
211  itk::ThreadIdType threadId)
212 {
213  bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
214 
215  bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
216  // Get the input pointers
217  InputImageConstPointerType inputPtr = this->GetInput();
218  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
219  OutputImagePointerType outputPtr = this->GetOutput();
220  ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
221  ProbaImagePointerType probaPtr = this->GetOutputProba();
222 
223  // Progress reporting
224  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
225 
226  // Define iterators
227  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
228  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
229  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
230  typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
231  typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
232 
233  InputIteratorType inIt(inputPtr, outputRegionForThread);
234  OutputIteratorType outIt(outputPtr, outputRegionForThread);
235 
236  MaskIteratorType maskIt;
237  if (inputMaskPtr)
238  {
239  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
240  maskIt.GoToBegin();
241  }
242 
243  typedef typename ModelType::InputSampleType InputSampleType;
244  typedef typename ModelType::InputListSampleType InputListSampleType;
245  typedef typename ModelType::TargetValueType TargetValueType;
246  typedef typename ModelType::TargetListSampleType TargetListSampleType;
247  typedef typename ModelType::ConfidenceListSampleType ConfidenceListSampleType;
248  typedef typename ModelType::ProbaListSampleType ProbaListSampleType;
249  typename InputListSampleType::Pointer samples = InputListSampleType::New();
250  unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
251  samples->SetMeasurementVectorSize(num_features);
252  InputSampleType sample(num_features);
253  // Fill the samples
254  bool validPoint = true;
255  for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
256  {
257  // Check pixel validity
258  if (inputMaskPtr)
259  {
260  validPoint = maskIt.Get() > 0;
261  ++maskIt;
262  }
263  if (validPoint)
264  {
265  typename InputImageType::PixelType pix = inIt.Get();
266  for (size_t feat = 0; feat < num_features; ++feat)
267  {
268  sample[feat] = pix[feat];
269  }
270  samples->PushBack(sample);
271  }
272  }
273  // Make the batch prediction
274  typename TargetListSampleType::Pointer labels;
275  typename ConfidenceListSampleType::Pointer confidences;
276  typename ProbaListSampleType::Pointer probas;
277  if (computeConfidenceMap)
278  confidences = ConfidenceListSampleType::New();
279 
280  if (computeProbaMap)
281  probas = ProbaListSampleType::New();
282  // This call is threadsafe
283  labels = m_Model->PredictBatch(samples, confidences, probas);
284 
285  // Set the output values
286  ConfidenceMapIteratorType confidenceIt;
287  if (computeConfidenceMap)
288  {
289  confidenceIt = ConfidenceMapIteratorType(confidencePtr, outputRegionForThread);
290  confidenceIt.GoToBegin();
291  }
292 
293  ProbaMapIteratorType probaIt;
294  if (computeProbaMap)
295  {
296  probaIt = ProbaMapIteratorType(probaPtr, outputRegionForThread);
297  probaIt.GoToBegin();
298  }
299  typename TargetListSampleType::ConstIterator labIt = labels->Begin();
300  maskIt.GoToBegin();
301  for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
302  {
303  double confidenceIndex = 0.0;
304  TargetValueType labelValue(m_DefaultLabel);
305  ProbaSampleType probaValues{m_NumberOfClasses};
306  if (inputMaskPtr)
307  {
308  validPoint = maskIt.Get() > 0;
309  ++maskIt;
310  }
311  if (validPoint && labIt != labels->End())
312  {
313  labelValue = labIt.GetMeasurementVector()[0];
314 
315  if (computeConfidenceMap)
316  {
317  confidenceIndex = confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0];
318  }
319  if (computeProbaMap)
320  {
321  // The probas may have different size than the m_NumberOfClasses set by the user
322  auto tempProbaValues = probas->GetMeasurementVector(labIt.GetInstanceIdentifier());
323  for (unsigned int i = 0; i < m_NumberOfClasses; ++i)
324  {
325  if (i < tempProbaValues.Size())
326  probaValues[i] = tempProbaValues[i];
327  else
328  probaValues[i] = 0;
329  }
330  }
331  ++labIt;
332  }
333  else
334  {
335  labelValue = m_DefaultLabel;
336  }
337 
338  outIt.Set(labelValue);
339 
340  if (computeConfidenceMap)
341  {
342  confidenceIt.Set(confidenceIndex);
343  ++confidenceIt;
344  }
345  if (computeProbaMap)
346  {
347  probaIt.Set(probaValues);
348  ++probaIt;
349  }
350  progress.CompletedPixel();
351  }
352 }
353 template <class TInputImage, class TOutputImage, class TMaskImage>
354 void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,
355  itk::ThreadIdType threadId)
356 {
357  if (m_BatchMode)
358  {
359  this->BatchThreadedGenerateData(outputRegionForThread, threadId);
360  }
361  else
362  {
363  this->ClassicThreadedGenerateData(outputRegionForThread, threadId);
364  }
365 }
void BatchThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
void SetInputMask(const MaskImageType *mask)
const MaskImageType * GetInputMask(void)
void ClassicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
otb::VectorImage< double > ProbaImageType
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId) override
ConfidenceImageType * GetOutputConfidence(void)
ProbaImageType * GetOutputProba(void)
void BeforeThreadedGenerateData() override
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.