21 #ifndef otbSOMImageClassificationFilter_hxx
22 #define otbSOMImageClassificationFilter_hxx
25 #include "itkImageRegionIterator.h"
26 #include "itkNumericTraits.h"
// Constructor fragment (the class-qualified signature line and braces are not
// visible in this excerpt). Sets the required-input count and initializes the
// default label.
33 template <
class TInputImage,
class TOutputImage,
class TSOMMap,
class TMaskImage>
// NOTE(review): SetNumberOfRequiredInputs is called twice back to back. The
// second call (1) presumably replaces the count set by the first (2), making the
// first call dead code. The apparent intent is that only the image (input #0) is
// mandatory and the mask (input #1) is optional — confirm and drop one call.
36 this->SetNumberOfRequiredInputs(2);
37 this->SetNumberOfRequiredInputs(1);
// Default output label, initialized to the zero value of LabelType.
38 m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
// Mask setter fragment (signature line not visible in this excerpt): stores
// `mask` as input #1 of the underlying itk::ProcessObject.
// The const_cast strips constness because SetNthInput presumably takes a
// non-const DataObject pointer — confirm against the ITK version in use.
42 template <
class TInputImage,
class TOutputImage,
class TSOMMap,
class TMaskImage>
45 this->itk::ProcessObject::SetNthInput(1,
const_cast<MaskImageType*
>(mask));
// Mask getter fragment (signature line not visible in this excerpt): returns
// process-object input #1 as a const MaskImageType pointer.
48 template <
class TInputImage,
class TOutputImage,
class TSOMMap,
class TMaskImage>
// Fewer than two inputs means no mask has been registered. The body of this
// branch is not visible here — presumably it returns a null pointer; confirm.
52 if (this->GetNumberOfInputs() < 2)
// NOTE(review): static_cast is an unchecked downcast; it assumes input #1 was
// set through the mask setter and really is a MaskImageType.
56 return static_cast<const MaskImageType*
>(this->itk::ProcessObject::GetInput(1));
// Pre-processing check fragment (function name and guarding condition are not
// visible in this excerpt). Raises an ITK exception with the message
// "No model for classification" — presumably when no SOM map (m_Map) has been
// set; confirm against the full source.
59 template <
class TInputImage,
class TOutputImage,
class TSOMMap,
class TMaskImage>
64 itkGenericExceptionMacro(<<
"No model for classification");
// Per-region classification worker fragment. The function-name line and many
// body lines (braces, if-guards, iterator increments) are not visible in this
// excerpt. The trailing parameter is an itk::ThreadIdType marked unused.
// Visible flow:
//   1. copy pixels of outputRegionForThread into a list sample;
//   2. classify the list sample against m_Map;
//   3. write labels back over the same region.
68 template <
class TInputImage,
class TOutputImage,
class TSOMMap,
class TMaskImage>
70 itk::ThreadIdType itkNotUsed(threadId))
72 InputImageConstPointerType inputPtr = this->GetInput();
// May be null when no mask was set (see the mask getter).
73 MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
74 OutputImagePointerType outputPtr = this->GetOutput();
// Read-only iterators for input and mask, a writable one for the output.
76 typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
77 typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
78 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
// Per-region sample list; measurement vector size = input components per pixel.
80 ListSamplePointerType listSample = ListSampleType::New();
81 listSample->SetMeasurementVectorSize(inputPtr->GetNumberOfComponentsPerPixel());
// The input is walked over outputRegionForThread, the same region the output
// iterator walks below, so sample order matches output pixel order.
83 InputIteratorType inIt(inputPtr, outputRegionForThread);
85 MaskIteratorType maskIt;
// Presumably only constructed when inputMaskPtr is non-null (the guard line is
// not visible here).
88 maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
// Samples are truncated to the shorter of the input pixel length and the SOM
// map's pixel length, so extra input bands are ignored.
// NOTE(review): when the input has more components than the map, sampleSize is
// smaller than the measurement vector size declared on listSample above;
// confirm this mismatch is tolerated by the list sample / classifier.
91 unsigned int maxDimension = m_Map->GetNumberOfComponentsPerPixel();
92 unsigned int sampleSize = std::min(inputPtr->GetNumberOfComponentsPerPixel(), maxDimension);
// Without a mask every pixel is treated as valid.
93 bool validPoint =
true;
// Collection pass over the input region.
95 for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
// A pixel is valid only where the mask value is strictly positive; invalid
// pixels are presumably not pushed into the sample list (branch not visible).
99 validPoint = maskIt.Get() > 0;
// Build a zero-initialized vector of length sampleSize, then copy the leading
// sampleSize components of the current input pixel into it.
105 sample.SetSize(sampleSize);
106 sample.Fill(itk::NumericTraits<ValueType>::ZeroValue());
107 for (
unsigned int i = 0; i < sampleSize; ++i)
109 sample[i] = inIt.Get()[i];
111 listSample->PushBack(sample);
// Classify the collected samples against the SOM map.
114 ClassifierPointerType classifier = ClassifierType::New();
115 classifier->SetMap(m_Map);
116 classifier->SetSample(listSample);
117 classifier->Update();
// Iterate over the classifier's membership output: one entry per pushed sample.
119 typename ClassifierType::OutputType::Pointer membershipSample = classifier->GetOutput();
120 typename ClassifierType::OutputType::ConstIterator sampleIter = membershipSample->Begin();
121 typename ClassifierType::OutputType::ConstIterator sampleLast = membershipSample->End();
123 OutputIteratorType outIt(outputPtr, outputRegionForThread);
// First output pass: fill with m_DefaultLabel.
// NOTE(review): this loop is also gated on sampleIter != sampleLast. If no
// samples were collected (e.g. the mask excludes the whole region),
// Begin() == End() presumably holds, and the loop writes nothing. The region's
// output pixels would then not receive m_DefaultLabel here — confirm and
// consider dropping the sample-iterator condition from this loop.
127 while (!outIt.IsAtEnd() && (sampleIter != sampleLast))
129 outIt.Set(m_DefaultLabel);
// Second output pass: walk the output (and, with a mask, the mask) in step.
// Valid pixels receive the next classifier label in collection order.
141 while (!outIt.IsAtEnd() && (sampleIter != sampleLast))
145 validPoint = maskIt.Get() > 0;
150 outIt.Set(sampleIter.GetClassLabel());