21 #ifndef otbDSFusionOfClassifiersImageFilter_hxx
22 #define otbDSFusionOfClassifiersImageFilter_hxx
25 #include "itkImageRegionIterator.h"
26 #include "itkProgressReporter.h"
28 #include "itkMetaDataObject.h"
// Fragment of a member template of the fusion filter. The signature line and
// the braces are not visible in this view; presumably this is the
// constructor -- TODO confirm.
38 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
// Two indexed input slots are declared, but only one input is mandatory.
41 this->SetNumberOfIndexedInputs(2);
42 this->SetNumberOfRequiredInputs(1);
// Start from an empty universe and use LabelType's zero value for both the
// no-data label and the undecided label.
45 this->m_Universe.clear();
46 this->m_LabelForNoDataPixels = itk::NumericTraits<LabelType>::ZeroValue();
47 this->m_LabelForUndecidedPixels = itk::NumericTraits<LabelType>::ZeroValue();
// Fragment of a member template that registers `mask` as indexed input 1 of
// the process object. The signature line is not visible in this view;
// presumably this is the mask setter -- TODO confirm.
51 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
// Constness of `mask` is cast away before the pointer is handed to
// itk::ProcessObject::SetNthInput.
54 this->itk::ProcessObject::SetNthInput(1,
const_cast<MaskImageType*
>(mask));
// Fragment of a member template that gives read access to indexed input 1 as
// a mask image. The signature line is not visible in this view; presumably
// this is the mask getter -- TODO confirm.
57 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
// Guard for the case where fewer than two inputs are set. The guarded
// statement is not visible here; presumably it returns a null pointer --
// TODO confirm.
61 if (this->GetNumberOfInputs() < 2)
// Otherwise input 1 is returned, cast to a pointer-to-const mask image.
65 return static_cast<const MaskImageType*
>(this->itk::ProcessObject::GetInput(1));
// Fragment of a member template that stores the per-classifier maps of
// masses of belief. Only the tail of the signature (its single pointer
// parameter) is visible in this view.
69 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
71 const VectorOfMapOfMassesOfBeliefType* ptrVectorOfMapOfMassesOfBelief)
// The pointed-to vector is copied into the member; the pointer is
// dereferenced directly and no null check is visible in this view.
73 this->m_VectorOfMapMOBs = *ptrVectorOfMapOfMassesOfBelief;
// Fragment of a member template that exposes the stored vector of maps of
// masses of belief. The signature line is not visible in this view.
76 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
// NOTE(review): this guard tests the number of inputs of the process object,
// whereas the value returned below is a member and not an input. The guarded
// statement is not visible here -- confirm that the check is intended.
80 if (this->GetNumberOfInputs() < 2)
// Returns the address of the member that holds the maps.
84 return &this->m_VectorOfMapMOBs;
// Fragment of a member template that extends the superclass'
// GenerateOutputInformation by writing no-data flags into the output image
// metadata. The signature line is not visible in this view.
88 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
91 Superclass::GenerateOutputInformation();
// One-element vectors are built: the no-data flag is set to true and the
// no-data value is m_LabelForNoDataPixels, stored in a vector of double.
94 std::vector<bool> noDataValueAvailable;
95 noDataValueAvailable.push_back(
true);
96 std::vector<double> noDataValue;
97 noDataValue.push_back(m_LabelForNoDataPixels);
// Both vectors are passed to WriteNoDataFlags together with the metadata of
// the output image.
99 WriteNoDataFlags(noDataValueAvailable, noDataValue, this->GetOutput()->GetImageMetadata());
// Fragment of a member template that derives per-classifier bookkeeping from
// m_VectorOfMapMOBs. The signature line and braces are not visible in this
// view; presumably this is the set-up step run before the threaded loop --
// TODO confirm.
103 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
// One map of masses of belief is expected per classifier.
110 m_NumberOfClassifiers = m_VectorOfMapMOBs.size();
115 m_VectorOfUniverseMOBs.clear();
118 for (
unsigned int itClk = 0; itClk < m_NumberOfClassifiers; ++itClk)
// Local copy of the label -> mass map of the current classifier.
121 LabelMassMapType mapMOBsClk = m_VectorOfMapMOBs[itClk];
// The universe mass is initialised to 0 and appended for this classifier; no
// other assignment to it is visible in this view.
125 MassType mobUniverseClk = 0.;
127 m_VectorOfUniverseMOBs.push_back(mobUniverseClk);
// Walk over the labels known to the current classifier.
131 typename LabelMassMapType::iterator itMapMOBClk;
132 for (itMapMOBClk = mapMOBsClk.begin(); itMapMOBClk != mapMOBsClk.end(); ++itMapMOBClk)
134 LabelType classLabel = itMapMOBClk->first;
// m_Universe counts occurrences per label: an existing entry is incremented,
// a new one is set to 1. The keyword separating the two assignments is not
// visible here; presumably an else -- TODO confirm.
137 if (m_Universe.count(classLabel) > 0)
139 m_Universe[classLabel]++;
143 m_Universe[classLabel] = 1;
// Number of distinct labels collected in m_Universe.
148 m_NumberOfClassesInUniverse = m_Universe.size();
// Fragment of a member template that fuses the labels given by all
// classifiers for one pixel (read from `vectorPixelValue`) into one output
// label. The signature line, braces and several branch keywords are not
// visible in this view; presumably this is OptimizedDSMassCombination, which
// the threaded loop of this file calls per pixel -- TODO confirm.
152 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
156 LabelType outFusedLabelOut = itk::NumericTraits<LabelType>::ZeroValue();
// Naming used below: a trailing underscore marks the mass of the complement
// of a label set, "prev" the value accumulated so far, "New" the updated one.
162 LabelType classLabelk;
163 MassType mLabelSetClk, mLabelSetClk_, mUniverseClk;
164 MassType mLabelSetClkprev, mLabelSetClkprev_, mUniverseClkprev;
165 MassType KClk, mLabelSetClkNew, mLabelSetClkNew_, mUniverseClkNew;
// Step I containers, all keyed by label: joint mass of the label, joint mass
// of its complement, joint mass of the universe, plus the labels that were
// seen with a mass of exactly 1 or exactly 0.
167 LabelMassMapType mapJointMassesStepI, mapJointMassesStepI_, mapJointMassesUniverseStepI;
168 LabelMassMapType mapOfLabelsWithMassOne, mapOfLabelsWithMassZero;
// Step I: classifiers are visited one by one and their masses are grouped by
// the label they voted for.
174 for (
unsigned int itClk = 0; itClk < m_NumberOfClassifiers; ++itClk)
177 classLabelk = vectorPixelValue[itClk];
// Votes equal to the no-data label are guarded against here; the extent of
// the guarded scope is not visible in this view.
179 if (classLabelk != m_LabelForNoDataPixels)
// Masses of the current classifier: universe, voted label, and the remainder
// assigned to the complement of the label.
// NOTE(review): if LabelMassMapType is a std::map, operator[] inserts a
// zero-valued entry into the member map for a label it does not yet contain.
// This code appears to be called from the threaded loop -- confirm that
// every label is present beforehand or that concurrent access is safe.
183 mUniverseClk = m_VectorOfUniverseMOBs[itClk];
184 mLabelSetClk = m_VectorOfMapMOBs[itClk][classLabelk];
185 mLabelSetClk_ = 1 - mLabelSetClk - mUniverseClk;
// Only masses strictly between 0 and 1 enter the joint-mass maps.
192 if ((mLabelSetClk > 0) && (mLabelSetClk < 1.0))
// First vote for this label: the classifier's masses are taken as they are.
194 if (mapJointMassesStepI.count(classLabelk) == 0)
197 mLabelSetClkNew = mLabelSetClk;
198 mLabelSetClkNew_ = mLabelSetClk_;
199 mUniverseClkNew = mUniverseClk;
// Label already seen (the keyword introducing this branch is not visible;
// presumably an else -- TODO confirm): the stored joint masses are combined
// with the current ones. KClk normalises by one minus the two cross products
// of a label mass with a complement mass.
204 mLabelSetClkprev = mapJointMassesStepI[classLabelk];
205 mLabelSetClkprev_ = mapJointMassesStepI_[classLabelk];
206 mUniverseClkprev = mapJointMassesUniverseStepI[classLabelk];
208 KClk = 1.0 / (1 - mLabelSetClkprev * mLabelSetClk_ - mLabelSetClkprev_ * mLabelSetClk);
209 mLabelSetClkNew = KClk * (mLabelSetClkprev * (mLabelSetClk + mUniverseClk) + mUniverseClkprev * mLabelSetClk);
210 mLabelSetClkNew_ = KClk * (mLabelSetClkprev_ * (mLabelSetClk_ + mUniverseClk) + mUniverseClkprev * mLabelSetClk_);
211 mUniverseClkNew = KClk * mUniverseClkprev * mUniverseClk;
// Store the updated joint masses for this label.
214 mapJointMassesStepI[classLabelk] = mLabelSetClkNew;
215 mapJointMassesStepI_[classLabelk] = mLabelSetClkNew_;
216 mapJointMassesUniverseStepI[classLabelk] = mUniverseClkNew;
// Masses of exactly 1 and exactly 0 are recorded separately.
220 if (mLabelSetClk == 1.0)
222 mapOfLabelsWithMassOne[classLabelk] = mLabelSetClk;
226 if (mLabelSetClk == 0)
228 mapOfLabelsWithMassZero[classLabelk] = 1.0;
// Shortcut decisions taken before the full combination.
// A single label group: its label is fetched. What is done with it is not
// visible in this view; presumably it is returned -- TODO confirm.
239 typename LabelMassMapType::iterator itMapMOBClk;
240 if (mapJointMassesStepI.size() == 1)
242 itMapMOBClk = mapJointMassesStepI.begin();
243 classLabelk = itMapMOBClk->first;
// At least one label had a mass of 1: when there is exactly one such label
// it is fetched (presumably returned -- TODO confirm); the undecided label
// is returned on the remaining path, whose branch keyword is not visible.
258 if (mapOfLabelsWithMassOne.size() > 0)
260 if (mapOfLabelsWithMassOne.size() == 1)
262 itMapMOBClk = mapOfLabelsWithMassOne.begin();
263 classLabelk = itMapMOBClk->first;
268 return m_LabelForUndecidedPixels;
// As many zero-mass labels as there are classifiers: undecided.
276 if (mapOfLabelsWithMassZero.size() == m_NumberOfClassifiers)
278 return m_LabelForUndecidedPixels;
// No joint mass was collected at all: no-data label.
285 if (mapJointMassesStepI.size() == 0)
287 return m_LabelForNoDataPixels;
// Step II accumulators: A sums m / (1 - m) and B multiplies (1 - m) over the
// label groups. C starts at 1; no update of C is visible in this view.
// NOTE(review): K is declared without an initialiser.
300 MassType A = 0, B = 1, C = 1, K;
302 for (itMapMOBClk = mapJointMassesStepI.begin(); itMapMOBClk != mapJointMassesStepI.end(); ++itMapMOBClk)
304 classLabelk = itMapMOBClk->first;
306 mLabelSetClk = mapJointMassesStepI[classLabelk];
307 mLabelSetClk_ = mapJointMassesStepI_[classLabelk];
308 mUniverseClk = mapJointMassesUniverseStepI[classLabelk];
310 A += (mLabelSetClk / (1 - mLabelSetClk));
311 B *= (1 - mLabelSetClk);
// The normalisation constant K depends on whether every class of the
// universe received a vote (C is subtracted) or only some of them did.
// NOTE(review): K stays unset if neither condition holds -- confirm that the
// number of groups can never exceed the number of classes in the universe.
320 unsigned int nbClkGroupsStepI = mapJointMassesStepI.size();
321 if (nbClkGroupsStepI == m_NumberOfClassesInUniverse)
323 K = 1.0 / ((1 + A) * B - C);
327 if (nbClkGroupsStepI < m_NumberOfClassesInUniverse)
329 K = 1.0 / ((1 + A) * B);
// Belief of every voted label and of its complement.
343 LabelMassMapType mapBelStepII, mapBelStepII_;
344 MassType belLabelSetClk, belLabelSetClk_, addBelLabelSetClk = 0.;
345 for (itMapMOBClk = mapJointMassesStepI.begin(); itMapMOBClk != mapJointMassesStepI.end(); ++itMapMOBClk)
347 classLabelk = itMapMOBClk->first;
350 mLabelSetClk = mapJointMassesStepI[classLabelk];
351 mLabelSetClk_ = mapJointMassesStepI_[classLabelk];
352 mUniverseClk = mapJointMassesUniverseStepI[classLabelk];
// NOTE(review): the second alternative compares the floating-point K with an
// integer count using == -- confirm that this is intended.
355 if ((nbClkGroupsStepI == m_NumberOfClassesInUniverse) || ((nbClkGroupsStepI == (m_NumberOfClassesInUniverse - 1)) && (K == m_NumberOfClassesInUniverse)))
// Belief including the universe contribution.
357 belLabelSetClk = K * ((mLabelSetClk / (1 - mLabelSetClk)) * B + (mUniverseClk * C / mLabelSetClk_));
// Belief without the universe contribution (the branch keyword is not
// visible here; presumably an else -- TODO confirm).
361 belLabelSetClk = K * (mLabelSetClk / (1 - mLabelSetClk)) * B;
365 belLabelSetClk_ = 1 - belLabelSetClk;
382 mapBelStepII[classLabelk] = belLabelSetClk;
383 mapBelStepII_[classLabelk] = belLabelSetClk_;
// Running sum of the beliefs of all voted labels.
384 addBelLabelSetClk += belLabelSetClk;
// Labels of the universe whose belief is 0 get the summed belief as the
// belief of their complement. With a std::map, operator[] also creates a
// zero entry for labels that received no vote.
391 typename ClassifierHistogramType::iterator itUniverse;
392 for (itUniverse = m_Universe.begin(); itUniverse != m_Universe.end(); ++itUniverse)
394 classLabelk = itUniverse->first;
398 if (mapBelStepII[classLabelk] == 0)
400 mapBelStepII_[classLabelk] = addBelLabelSetClk;
// Decision: keep the label with the largest belief. The first label of the
// universe seeds the search; because of >=, a later label with an equal
// belief replaces an earlier one.
426 MassType fusedDSBelLabelSetClk = 0.;
427 for (itUniverse = m_Universe.begin(); itUniverse != m_Universe.end(); ++itUniverse)
429 classLabelk = itUniverse->first;
431 if (itUniverse == m_Universe.begin())
433 outFusedLabelOut = classLabelk;
434 fusedDSBelLabelSetClk = mapBelStepII[classLabelk];
438 if (mapBelStepII[classLabelk] >= fusedDSBelLabelSetClk)
440 outFusedLabelOut = classLabelk;
441 fusedDSBelLabelSetClk = mapBelStepII[classLabelk];
// Tie check: if a different label reaches the same maximal belief, the
// result is replaced by the undecided label.
447 for (itUniverse = m_Universe.begin(); itUniverse != m_Universe.end(); ++itUniverse)
449 classLabelk = itUniverse->first;
450 if ((mapBelStepII[classLabelk] == fusedDSBelLabelSetClk) && (classLabelk != outFusedLabelOut))
452 outFusedLabelOut = m_LabelForUndecidedPixels;
461 return outFusedLabelOut;
// Fragment of a member template that fills the output over the region
// assigned to one thread. Only the tail of the signature (the thread id
// parameter) is visible in this view; braces and some guards are missing.
465 template <
class TInputImage,
class TOutputImage,
class TMaskImage>
467 itk::ThreadIdType threadId)
// Fetch the input image, the mask and the output image.
470 InputImageConstPointerType inputPtr = this->GetInput();
471 MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
472 OutputImagePointerType outputPtr = this->GetOutput();
// Progress is reported once per pixel of this thread's region.
475 itk::ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
478 typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
479 typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
480 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
482 InputIteratorType inIt(inputPtr, outputRegionForThread);
483 OutputIteratorType outIt(outputPtr, outputRegionForThread);
// The mask iterator is default-constructed and then bound to the mask; the
// condition guarding the binding is not visible in this view (presumably
// "a mask was provided" -- TODO confirm).
486 MaskIteratorType maskIt;
489 maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
// Pixels count as valid unless the mask says otherwise.
493 bool validPoint =
true;
// Input and output are walked in lockstep over the thread region.
496 for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
// A pixel is valid when its mask value is strictly positive.
// NOTE(review): no increment of maskIt is visible in this view (the loop
// header advances only inIt and outIt) -- confirm that it is advanced on a
// line that is missing here.
501 validPoint = maskIt.Get() > 0;
// Valid pixels receive the fused label; the other assignment writes the
// no-data label (the branch keywords are not visible in this view).
508 outIt.Set(this->OptimizedDSMassCombination(inIt.Get()));
513 outIt.Set(m_LabelForNoDataPixels);
515 progress.CompletedPixel();