21 #ifndef otbImageClassificationFilter_hxx
22 #define otbImageClassificationFilter_hxx
26 #include "itkImageRegionIterator.h"
27 #include "itkProgressReporter.h"
34 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
37 this->DynamicMultiThreadingOff();
38 this->SetNumberOfIndexedInputs(2);
39 this->SetNumberOfRequiredInputs(1);
40 m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
43 this->SetNumberOfRequiredOutputs(3);
44 this->SetNthOutput(0, TOutputImage::New());
45 this->SetNthOutput(1, ConfidenceImageType::New());
46 this->SetNthOutput(2, ProbaImageType::New());
47 m_UseConfidenceMap =
false;
48 m_UseProbaMap =
false;
50 m_NumberOfClasses = 1;
53 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
56 this->itk::ProcessObject::SetNthInput(1,
const_cast<MaskImageType*
>(mask));
59 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
63 if (this->GetNumberOfInputs() < 2)
67 return static_cast<const MaskImageType*
>(this->itk::ProcessObject::GetInput(1));
70 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
74 if (this->GetNumberOfOutputs() < 2)
78 return static_cast<ConfidenceImageType*
>(this->itk::ProcessObject::GetOutput(1));
81 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
85 if (this->GetNumberOfOutputs() < 2)
89 return static_cast<ProbaImageType*
>(this->itk::ProcessObject::GetOutput(2));
92 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
97 itkGenericExceptionMacro(<<
"No model for classification");
103 this->SetNumberOfWorkUnits(1);
108 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
110 itk::ThreadIdType threadId)
113 InputImageConstPointerType inputPtr = this->GetInput();
114 MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
115 OutputImagePointerType outputPtr = this->GetOutput();
116 ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
117 ProbaImagePointerType probaPtr = this->GetOutputProba();
119 itk::ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
122 typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
123 typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
124 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
125 typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
126 typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
128 InputIteratorType inIt(inputPtr, outputRegionForThread);
129 OutputIteratorType outIt(outputPtr, outputRegionForThread);
132 MaskIteratorType maskIt;
135 maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
140 bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
141 ConfidenceMapIteratorType confidenceIt;
142 if (computeConfidenceMap)
144 confidenceIt = ConfidenceMapIteratorType(confidencePtr, outputRegionForThread);
145 confidenceIt.GoToBegin();
149 bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
151 ProbaMapIteratorType probaIt;
155 probaIt = ProbaMapIteratorType(probaPtr, outputRegionForThread);
159 bool validPoint =
true;
160 double confidenceIndex = 0.0;
161 ProbaSampleType probaVector{m_NumberOfClasses};
164 for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
169 validPoint = maskIt.Get() > 0;
178 outIt.Set(m_Model->Predict(inIt.Get(), &confidenceIndex, &probaVector)[0]);
180 else if (computeConfidenceMap)
182 outIt.Set(m_Model->Predict(inIt.Get(), &confidenceIndex)[0]);
186 outIt.Set(m_Model->Predict(inIt.Get())[0]);
192 outIt.Set(m_DefaultLabel);
193 confidenceIndex = 0.0;
195 if (computeConfidenceMap)
197 confidenceIt.Set(confidenceIndex);
202 probaIt.Set(probaVector);
205 progress.CompletedPixel();
209 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
211 itk::ThreadIdType threadId)
213 bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
215 bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
217 InputImageConstPointerType inputPtr = this->GetInput();
218 MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
219 OutputImagePointerType outputPtr = this->GetOutput();
220 ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
221 ProbaImagePointerType probaPtr = this->GetOutputProba();
224 itk::ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
227 typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
228 typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
229 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
230 typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
231 typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
233 InputIteratorType inIt(inputPtr, outputRegionForThread);
234 OutputIteratorType outIt(outputPtr, outputRegionForThread);
236 MaskIteratorType maskIt;
239 maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
243 typedef typename ModelType::InputSampleType InputSampleType;
244 typedef typename ModelType::InputListSampleType InputListSampleType;
245 typedef typename ModelType::TargetValueType TargetValueType;
246 typedef typename ModelType::TargetListSampleType TargetListSampleType;
247 typedef typename ModelType::ConfidenceListSampleType ConfidenceListSampleType;
248 typedef typename ModelType::ProbaListSampleType ProbaListSampleType;
249 typename InputListSampleType::Pointer samples = InputListSampleType::New();
250 unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
251 samples->SetMeasurementVectorSize(num_features);
252 InputSampleType sample(num_features);
254 bool validPoint =
true;
255 for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
260 validPoint = maskIt.Get() > 0;
265 typename InputImageType::PixelType pix = inIt.Get();
266 for (
size_t feat = 0; feat < num_features; ++feat)
268 sample[feat] = pix[feat];
270 samples->PushBack(sample);
274 typename TargetListSampleType::Pointer labels;
275 typename ConfidenceListSampleType::Pointer confidences;
276 typename ProbaListSampleType::Pointer probas;
277 if (computeConfidenceMap)
278 confidences = ConfidenceListSampleType::New();
281 probas = ProbaListSampleType::New();
283 labels = m_Model->PredictBatch(samples, confidences, probas);
286 ConfidenceMapIteratorType confidenceIt;
287 if (computeConfidenceMap)
289 confidenceIt = ConfidenceMapIteratorType(confidencePtr, outputRegionForThread);
290 confidenceIt.GoToBegin();
293 ProbaMapIteratorType probaIt;
296 probaIt = ProbaMapIteratorType(probaPtr, outputRegionForThread);
299 typename TargetListSampleType::ConstIterator labIt = labels->Begin();
301 for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
303 double confidenceIndex = 0.0;
304 TargetValueType labelValue(m_DefaultLabel);
305 ProbaSampleType probaValues{m_NumberOfClasses};
308 validPoint = maskIt.Get() > 0;
311 if (validPoint && labIt != labels->End())
313 labelValue = labIt.GetMeasurementVector()[0];
315 if (computeConfidenceMap)
317 confidenceIndex = confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0];
322 auto tempProbaValues = probas->GetMeasurementVector(labIt.GetInstanceIdentifier());
323 for (
unsigned int i = 0; i < m_NumberOfClasses; ++i)
325 if (i < tempProbaValues.Size())
326 probaValues[i] = tempProbaValues[i];
335 labelValue = m_DefaultLabel;
338 outIt.Set(labelValue);
340 if (computeConfidenceMap)
342 confidenceIt.Set(confidenceIndex);
347 probaIt.Set(probaValues);
350 progress.CompletedPixel();
353 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
355 itk::ThreadIdType threadId)
359 this->BatchThreadedGenerateData(outputRegionForThread, threadId);
363 this->ClassicThreadedGenerateData(outputRegionForThread, threadId);
void BatchThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
otb::Image< double > ConfidenceImageType
void SetInputMask(const MaskImageType *mask)
ImageClassificationFilter()
const MaskImageType * GetInputMask(void)
void ClassicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
otb::VectorImage< double > ProbaImageType
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId) override
ConfidenceImageType * GetOutputConfidence(void)
ProbaImageType * GetOutputProba(void)
void BeforeThreadedGenerateData() override
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.