21 #ifndef otbImageClassificationFilter_hxx
22 #define otbImageClassificationFilter_hxx
25 #include "itkImageRegionIterator.h"
26 #include "itkProgressReporter.h"
33 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
36 this->SetNumberOfIndexedInputs(2);
37 this->SetNumberOfRequiredInputs(1);
38 m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
41 this->SetNumberOfRequiredOutputs(3);
42 this->SetNthOutput(0, TOutputImage::New());
43 this->SetNthOutput(1, ConfidenceImageType::New());
44 this->SetNthOutput(2, ProbaImageType::New());
45 m_UseConfidenceMap =
false;
46 m_UseProbaMap =
false;
48 m_NumberOfClasses = 1;
51 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
54 this->itk::ProcessObject::SetNthInput(1,
const_cast<MaskImageType*
>(mask));
57 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
61 if (this->GetNumberOfInputs() < 2)
65 return static_cast<const MaskImageType*
>(this->itk::ProcessObject::GetInput(1));
68 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
72 if (this->GetNumberOfOutputs() < 2)
76 return static_cast<ConfidenceImageType*
>(this->itk::ProcessObject::GetOutput(1));
79 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
83 if (this->GetNumberOfOutputs() < 2)
87 return static_cast<ProbaImageType*
>(this->itk::ProcessObject::GetOutput(2));
90 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
95 itkGenericExceptionMacro(<<
"No model for classification");
101 this->SetNumberOfThreads(1);
106 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
108 itk::ThreadIdType threadId)
111 InputImageConstPointerType inputPtr = this->GetInput();
112 MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
113 OutputImagePointerType outputPtr = this->GetOutput();
114 ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
115 ProbaImagePointerType probaPtr = this->GetOutputProba();
117 itk::ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
120 typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
121 typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
122 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
123 typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
124 typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
126 InputIteratorType inIt(inputPtr, outputRegionForThread);
127 OutputIteratorType outIt(outputPtr, outputRegionForThread);
130 MaskIteratorType maskIt;
133 maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
138 bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
139 ConfidenceMapIteratorType confidenceIt;
140 if (computeConfidenceMap)
142 confidenceIt = ConfidenceMapIteratorType(confidencePtr, outputRegionForThread);
143 confidenceIt.GoToBegin();
147 bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
149 ProbaMapIteratorType probaIt;
153 probaIt = ProbaMapIteratorType(probaPtr, outputRegionForThread);
157 bool validPoint =
true;
158 double confidenceIndex = 0.0;
159 ProbaSampleType probaVector{m_NumberOfClasses};
162 for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
167 validPoint = maskIt.Get() > 0;
176 outIt.Set(m_Model->Predict(inIt.Get(), &confidenceIndex, &probaVector)[0]);
178 else if (computeConfidenceMap)
180 outIt.Set(m_Model->Predict(inIt.Get(), &confidenceIndex)[0]);
184 outIt.Set(m_Model->Predict(inIt.Get())[0]);
190 outIt.Set(m_DefaultLabel);
191 confidenceIndex = 0.0;
193 if (computeConfidenceMap)
195 confidenceIt.Set(confidenceIndex);
200 probaIt.Set(probaVector);
203 progress.CompletedPixel();
207 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
209 itk::ThreadIdType threadId)
211 bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
213 bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
215 InputImageConstPointerType inputPtr = this->GetInput();
216 MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
217 OutputImagePointerType outputPtr = this->GetOutput();
218 ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
219 ProbaImagePointerType probaPtr = this->GetOutputProba();
222 itk::ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
225 typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
226 typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
227 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
228 typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
229 typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
231 InputIteratorType inIt(inputPtr, outputRegionForThread);
232 OutputIteratorType outIt(outputPtr, outputRegionForThread);
234 MaskIteratorType maskIt;
237 maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
241 typedef typename ModelType::InputSampleType InputSampleType;
242 typedef typename ModelType::InputListSampleType InputListSampleType;
243 typedef typename ModelType::TargetValueType TargetValueType;
244 typedef typename ModelType::TargetListSampleType TargetListSampleType;
245 typedef typename ModelType::ConfidenceListSampleType ConfidenceListSampleType;
246 typedef typename ModelType::ProbaListSampleType ProbaListSampleType;
247 typename InputListSampleType::Pointer samples = InputListSampleType::New();
248 unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
249 samples->SetMeasurementVectorSize(num_features);
250 InputSampleType sample(num_features);
252 bool validPoint =
true;
253 for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
258 validPoint = maskIt.Get() > 0;
263 typename InputImageType::PixelType pix = inIt.Get();
264 for (
size_t feat = 0; feat < num_features; ++feat)
266 sample[feat] = pix[feat];
268 samples->PushBack(sample);
272 typename TargetListSampleType::Pointer labels;
273 typename ConfidenceListSampleType::Pointer confidences;
274 typename ProbaListSampleType::Pointer probas;
275 if (computeConfidenceMap)
276 confidences = ConfidenceListSampleType::New();
279 probas = ProbaListSampleType::New();
281 labels = m_Model->PredictBatch(samples, confidences, probas);
284 ConfidenceMapIteratorType confidenceIt;
285 if (computeConfidenceMap)
287 confidenceIt = ConfidenceMapIteratorType(confidencePtr, outputRegionForThread);
288 confidenceIt.GoToBegin();
291 ProbaMapIteratorType probaIt;
294 probaIt = ProbaMapIteratorType(probaPtr, outputRegionForThread);
297 typename TargetListSampleType::ConstIterator labIt = labels->Begin();
299 for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
301 double confidenceIndex = 0.0;
302 TargetValueType labelValue(m_DefaultLabel);
303 ProbaSampleType probaValues{m_NumberOfClasses};
306 validPoint = maskIt.Get() > 0;
309 if (validPoint && labIt != labels->End())
311 labelValue = labIt.GetMeasurementVector()[0];
313 if (computeConfidenceMap)
315 confidenceIndex = confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0];
320 auto tempProbaValues = probas->GetMeasurementVector(labIt.GetInstanceIdentifier());
321 for (
unsigned int i = 0; i < m_NumberOfClasses; ++i)
323 if (i < tempProbaValues.Size())
324 probaValues[i] = tempProbaValues[i];
333 labelValue = m_DefaultLabel;
336 outIt.Set(labelValue);
338 if (computeConfidenceMap)
340 confidenceIt.Set(confidenceIndex);
345 probaIt.Set(probaValues);
348 progress.CompletedPixel();
351 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
353 itk::ThreadIdType threadId)
357 this->BatchThreadedGenerateData(outputRegionForThread, threadId);
361 this->ClassicThreadedGenerateData(outputRegionForThread, threadId);