22 #ifndef otbSEMClassifier_hxx
23 #define otbSEMClassifier_hxx
28 #include "itkNumericTraits.h"
29 #include "itkImageIterator.h"
30 #include "itkImageRegionIterator.h"
// Constructor fragment. The signature is not shown in this excerpt
// (presumably SEMClassifier<TInputImage, TOutputImage>::SEMClassifier() --
// TODO confirm). The visible lines put the classifier in its
// "nothing computed yet" state.
40 template <
class TInputImage,
class TOutputImage>
43 m_TerminationCode = NOT_CONVERGED;
// No mixture component has been declared through AddComponent() yet.
45 m_ComponentDeclared = 0;
48 m_SampleList =
nullptr;
// Default convergence threshold. Update() compares
// (oldNbChange - m_NbChange) / m_NbSamples against GetTerminationThreshold(),
// which presumably returns this member -- confirm.
50 m_TerminationThreshold = 1E-5;
// The output label image is only allocated later, in GetMaximumAposterioriLabels().
53 m_OutputImage =
nullptr;
// PrintSelf: prints the superclass state, then one entry per class with the
// component index, its current mixing proportion (m_Proportions) and the
// component's own parameters via ShowParameters().
// NOTE(review): m_Proportions and m_ComponentVector are indexed up to
// GetNumberOfClasses() - 1 without any check. m_Proportions is resized in
// the parameter-initialisation step, and the m_ComponentVector slots created
// by Modified() stay empty (presumably null) until AddComponent() fills
// them. Printing a freshly configured object may therefore read out of range
// or dereference null -- confirm whether lines not shown here guard this.
57 template <
class TInputImage,
class TOutputImage>
60 Superclass::PrintSelf(os, indent);
62 const unsigned int nbClasses = this->GetNumberOfClasses();
64 for (
unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
66 os << indent <<
"Component num " << componentIndex;
67 os <<
" (prop " << m_Proportions[componentIndex] <<
") ";
68 m_ComponentVector[componentIndex]->ShowParameters(os, indent);
// SetNeighborhood: stores the contextual window size, forced to be odd.
// 2 * (n / 2) + 1 maps an even n >= 0 to n + 1 and keeps a positive odd n
// unchanged; integer division truncates toward zero.
// Inputs of -2 or less still produce a value below 1, which the following
// check catches (its branch body is not shown in this excerpt).
72 template <
class TInputImage,
class TOutputImage>
75 m_Neighborhood = 2 * (neighborhood / 2) + 1;
76 if (m_Neighborhood < 1)
// GetNeighborhood: returns the contextual window size stored by
// SetNeighborhood().
80 template <
class TInputImage,
class TOutputImage>
83 return m_Neighborhood;
// SetInitialProportions: keeps a copy of the requested class priors.
// ProportionVectorType is a std::vector<double>, so this assignment copies.
// The parameter-initialisation step uses them to draw the initial labels
// only if no external labels were supplied and their count equals
// GetNumberOfClasses(). That step also rescales the stored copy in place
// when the values do not sum to exactly 1.0.
86 template <
class TInputImage,
class TOutputImage>
89 m_InitialProportions = proportions;
// SetClassLabels (MembershipSample overload): copies the class label of
// every instance of `labels` into m_ClassLabels and sets
// m_ExternalLabels = 1. The parameter-initialisation step then skips
// drawing random initial labels.
// Presumably the final else throws itk::ExceptionObject("Vector size
// mismatch") when the label count cannot be matched to the sample.
// NOTE(review): both copy loops are do/while, so an empty `labels`
// would still dereference labels->Begin() once.
93 template <
class TInputImage,
class TOutputImage>
// First branch (its condition is not shown in this excerpt).
98 m_ClassLabels.resize(labels->Size());
100 ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
101 typename OutputType::ConstIterator iterLabels = labels->Begin();
102 typename OutputType::InstanceIdentifier
id = 0;
106 *iterClassLabel = iterLabels->GetClassLabel(
id);
110 }
while (iterLabels != labels->End());
111 m_ExternalLabels = 1;
// NOTE(review): this guard tests the size of the *current* m_ClassLabels,
// not labels->Size(), and then resizes to labels->Size(). A label set whose
// length differs from m_NbSamples is therefore accepted here. The image
// overload of SetClassLabels compares the incoming size instead -- confirm
// which check is intended.
113 else if (m_ClassLabels.size() == m_NbSamples)
116 m_ClassLabels.resize(labels->Size());
118 ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
// NOTE(review): the first branch uses OutputType::ConstIterator, but this one
// spells `OutputType::iterator`. itk::Statistics::MembershipSample documents
// `Iterator` / `ConstIterator`. Verify this name exists, otherwise the
// branch fails to compile once the template is instantiated.
119 typename OutputType::iterator iterLabels = labels->Begin();
120 typename OutputType::InstanceIdentifier
id = 0;
124 *iterClassLabel = iterLabels->GetClassLabel(
id);
128 }
while (iterLabels != labels->End());
129 m_ExternalLabels = 1;
133 otbMsgDebugMacro(<<
"m_ClassLabels size = " << GetClassLabels().size() <<
" / m_Sample size = " << m_NbSamples);
134 throw itk::ExceptionObject(__FILE__, __LINE__,
"Vector size mismatch", ITK_LOCATION);
// SetClassLabels (label-image overload): copies the pixels of the label image
// into m_ClassLabels. Pixels are visited in ImageRegionIterator order over
// the buffered region. The function then sets m_ExternalLabels = 1.
// Presumably the final else throws itk::ExceptionObject("Vector size
// mismatch") when the image size does not match the sample.
138 template <
class TInputImage,
class TOutputImage>
// NOTE(review): the label image's size is read through TInputImage::SizeType,
// which assumes both image types share a dimension -- confirm.
// The C-style cast applies to size[0] alone, so size[0] * size[1] is
// evaluated in the unsigned size type and then narrowed to int.
// Only the first two dimensions are counted, yet the iterators below walk the
// whole buffered region. For a label image with more than two dimensions the
// copy would run past the end of m_ClassLabels (assuming the hidden lines
// advance iterClassLabel once per pixel).
143 typename TInputImage::SizeType size = imgLabels->GetBufferedRegion().GetSize();
144 int theSize = (int)size[0] * size[1];
// No input sample set yet: accept a label image of any size.
146 if (m_Sample ==
nullptr)
148 m_ClassLabels.resize(theSize);
150 typename itk::ImageRegionIterator<TOutputImage> imgLabelIter(imgLabels, imgLabels->GetBufferedRegion());
151 imgLabelIter.GoToBegin();
152 typename itk::ImageRegionIterator<TOutputImage> imgLabelIterEnd(imgLabels, imgLabels->GetBufferedRegion());
153 imgLabelIterEnd.GoToEnd();
155 ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
159 *iterClassLabel = imgLabelIter.Get();
162 }
while (imgLabelIter != imgLabelIterEnd);
163 m_ExternalLabels = 1;
// A sample is already set: the label image must hold exactly as many pixels.
165 else if (theSize == m_NbSamples)
167 m_ClassLabels.resize(theSize);
169 typename itk::ImageRegionIterator<TOutputImage> imgLabelIter(imgLabels, imgLabels->GetBufferedRegion());
170 imgLabelIter.GoToBegin();
171 typename itk::ImageRegionIterator<TOutputImage> imgLabelIterEnd(imgLabels, imgLabels->GetBufferedRegion());
172 imgLabelIterEnd.GoToEnd();
174 ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
178 *iterClassLabel = imgLabelIter.Get();
181 }
while (imgLabelIter != imgLabelIterEnd);
182 m_ExternalLabels = 1;
186 otbMsgDebugMacro(<<
"m_ClassLabels size = " << GetClassLabels().size() <<
" size of the image = " << theSize <<
" / m_Sample size = " << m_NbSamples);
187 throw itk::ExceptionObject(__FILE__, __LINE__,
"Vector size mismatch", ITK_LOCATION);
// GetClassLabels: returns the current per-sample label vector. The member
// list at the end of this excerpt gives the return type as
// `ClassLabelVectorType &`, i.e. a mutable reference to internal state.
191 template <
class TInputImage,
class TOutputImage>
194 return m_ClassLabels;
// SetSample fragment: builds m_SampleList, a ListSample holding one
// measurement vector per pixel of the input's buffered region. The
// assignment of m_Sample itself is on a line not shown here.
// If external labels were already supplied, their count must match the
// sample's first two dimensions.
197 template <
class TInputImage,
class TOutputImage>
202 m_SampleList = SampleType::New();
203 m_SampleList->SetMeasurementVectorSize(m_Sample->GetVectorLength());
// NOTE(review): SetSample takes `const TInputImage *`, so m_Sample is
// presumably const. These C-style casts strip const only so a writable
// ImageRegionIterator can be built, but the loop only reads pixels.
// itk::ImageRegionConstIterator would avoid the cast.
205 itk::ImageRegionIterator<TInputImage> imgIter((TInputImage*)m_Sample, m_Sample->GetBufferedRegion());
207 itk::ImageRegionIterator<TInputImage> imgIterEnd((TInputImage*)m_Sample, m_Sample->GetBufferedRegion());
208 imgIterEnd.GoToEnd();
212 m_SampleList->PushBack(imgIter.Get());
// do/while: the body runs once before the end test, so an empty buffered
// region would still read imgIter once.
215 }
while (imgIter != imgIterEnd);
// Labels set beforehand: the pixel count (first two dimensions only) must
// equal the label count.
217 if (m_ExternalLabels)
219 typename TInputImage::SizeType size = m_Sample->GetBufferedRegion().GetSize();
220 if ((size[0] * size[1]) != m_ClassLabels.size())
221 throw itk::ExceptionObject(__FILE__, __LINE__,
"Vector size mismatch", ITK_LOCATION);
// Template header only; the body is not present in this excerpt. Given its
// position in the file this is presumably GetSample(), listed as
// `const TInputImage * GetSample() const` in the member list at the end --
// TODO confirm.
225 template <
class TInputImage,
class TOutputImage>
// Template header only; the body is not present in this excerpt. Given its
// position in the file this is presumably GetSampleList(), listed as
// `SampleType * GetSampleList() const` in the member list at the end --
// TODO confirm.
231 template <
class TInputImage,
class TOutputImage>
// GetCurrentIteration: returns the SEM iteration counter. Update() resets it
// to 0 and increments it in its loop condition.
237 template <
class TInputImage,
class TOutputImage>
240 return m_CurrentIteration;
// AddComponent: stores `component` as the mixture component for class `id`.
// It marks that components were declared explicitly, so the default
// Gaussian initialisation is skipped. It returns the size of the component
// vector, not the index that was just filled.
// NOTE(review): operator[] performs no range check. `id` must lie in
// [0, m_ComponentVector.size()), a size Modified() sets to
// GetNumberOfClasses(). An out-of-range id is undefined behaviour.
243 template <
class TInputImage,
class TOutputImage>
246 m_ComponentVector[id] = component;
247 m_ComponentDeclared = 1;
249 return static_cast<int>(m_ComponentVector.size());
// Modified: forwards to the superclass, then discards every declared
// component and resizes m_ComponentVector to GetNumberOfClasses() empty
// slots. It also clears m_ComponentDeclared. No brace is visible, so the
// `if` presumably guards only the debug message.
// NOTE(review): the member list declares this as `void Modified() const
// override`, yet it mutates m_ComponentVector and m_ComponentDeclared.
// Those members must be `mutable` for this to compile -- confirm.
// ITK setters typically call Modified(), so any such setter invoked after
// AddComponent() would also drop the declared components.
252 template <
class TInputImage,
class TOutputImage>
255 Superclass::Modified();
256 if (m_ComponentDeclared == 1)
257 otbMsgDebugMacro(<<
"Previous component declarations will be lost since called before SetNumberOfClasses");
258 m_ComponentVector.clear();
259 m_ComponentVector.resize(this->GetNumberOfClasses());
260 m_ComponentDeclared = 0;
// GetOutputImage: returns the MAP label image. The constructor sets it to
// nullptr and GetMaximumAposterioriLabels() allocates it, so callers
// receive nullptr before a run has completed.
263 template <
class TInputImage,
class TOutputImage>
266 return m_OutputImage;
// Template header only; neither the signature nor the body of this member is
// present in this excerpt.
269 template <
class TInputImage,
class TOutputImage>
// Parameter-initialisation fragment. The signature is not shown (presumably
// InitParameters() -- TODO confirm). Visible steps:
//  1. Without external labels, draw an initial label per sample, either
//     uniformly or according to m_InitialProportions.
//  2. Size m_Proportions (nbClasses) and m_Proba (nbClasses x m_NbSamples).
//  3. If no component was declared, create one Gaussian component per class.
275 template <
class TInputImage,
class TOutputImage>
278 unsigned int nbClasses = this->GetNumberOfClasses();
280 if (!m_ExternalLabels)
282 m_ClassLabels.resize(m_NbSamples);
// Priors unusable (wrong count): draw uniformly random labels.
283 if (
static_cast<unsigned int>(m_InitialProportions.size()) != nbClasses)
286 for (
typename ClassLabelVectorType::iterator labelIter = m_ClassLabels.begin(); labelIter != m_ClassLabels.end(); ++labelIter)
// NOTE(review): rand() % nbClasses is always < nbClasses, so the clamp below
// looks unreachable. With nbClasses == 0 the modulo is a division by zero.
// rand() % n also carries a small modulo bias.
289 label = rand() % nbClasses;
290 if (label >= nbClasses)
292 label = nbClasses - 1;
// Otherwise (the else line is not shown): normalise the priors, then draw
// each label by inverse-CDF sampling.
300 double sumProportion = 0.0;
301 typename ProportionVectorType::iterator iterProportion = m_InitialProportions.begin();
304 sumProportion += *iterProportion;
305 }
while (++iterProportion != m_InitialProportions.end());
// NOTE(review): this is an exact floating-point comparison, and a sum of 0.0
// would divide by zero below. The summation above is a do/while, so it
// dereferences begin() even for an empty vector. That case is reachable here
// only when nbClasses == 0.
307 if (sumProportion != 1.0)
309 for (iterProportion = m_InitialProportions.begin(); iterProportion != m_InitialProportions.end(); ++iterProportion)
310 *iterProportion /= sumProportion;
315 double cumulativeProportion;
316 for (
typename ClassLabelVectorType::iterator labelIter = m_ClassLabels.begin(); labelIter != m_ClassLabels.end(); ++labelIter)
318 cumulativeProportion = 0.0;
// Uniform draw in [0, 1).
319 sample = double(rand()) / (double(RAND_MAX) + 1.0);
// Default to the last class, so a draw beyond the rounded cumulative sum
// still gets a label.
321 *labelIter = nbClasses - 1;
322 for (
unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
324 if (cumulativeProportion <= sample && sample < cumulativeProportion + m_InitialProportions[componentIndex])
326 *labelIter = componentIndex;
330 cumulativeProportion += m_InitialProportions[componentIndex];
// Per-class storage: mixing proportions and posterior probabilities per sample.
336 m_Proportions.resize(nbClasses);
337 m_Proba.resize(nbClasses);
338 for (
unsigned int i = 0; i < nbClasses; ++i)
339 m_Proba[i].resize(m_NbSamples);
// Default mixture: one Gaussian component per class when none was declared.
// AddComponent() also sets m_ComponentDeclared.
341 if (!m_ComponentDeclared)
343 otbMsgDebugMacro(<<
"default mixture initialization with " << nbClasses <<
" Gaussian components");
346 for (
unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
348 typename GaussianType::Pointer comp = GaussianType::New();
349 AddComponent(componentIndex, comp);
356 otbMsgDevMacro(<<
"contextual neighborhood : " << m_Neighborhood);
357 otbMsgDevMacro(<<
"terminationThreshold : " << m_TerminationThreshold);
// PerformStochasticProcess (S-step): for every sample, draws x uniformly in
// [0, 1). It re-assigns the sample's label to the class whose cumulative
// posterior interval (y, z] contains x. Presumably y holds the previous
// cumulative value. Changes are presumably counted in m_NbChange, which is
// reported at the end.
// Several lines are not shown in this excerpt: the declarations of
// x/y/z/posSample, the change counter, and the switch cases.
// NOTE(review): if a sample's posteriors sum to less than 1, x can fall
// above every z and the label is silently kept. This happens, for example,
// when the expectation step sets them all to 0. If y starts at 0, a draw of
// exactly 0.0 also matches no interval.
360 template <
class TInputImage,
class TOutputImage>
363 unsigned int nbClasses = this->GetNumberOfClasses();
369 for (
typename ClassLabelVectorType::iterator iter = m_ClassLabels.begin(); iter != m_ClassLabels.end(); ++iter)
371 x = double(rand()) / (double(RAND_MAX) + 1.0);
374 for (
unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
377 z += m_Proba[componentIndex][posSample];
379 if ((y < x) && (x <= z))
381 if (
static_cast<unsigned int>(componentIndex) != *iter)
384 *iter = componentIndex;
391 switch (GetCurrentIteration())
398 otbMsgDebugMacro(<< m_NbChange <<
" sample change at iteration " << GetCurrentIteration());
// PerformMaximizationProcess (M-step):
//  - splits m_SampleList into one per-class sample set (ClassSampleType,
//    presumably an itk::Statistics::Subsample) according to m_ClassLabels;
//  - re-estimates each component from its set (SetSample + Update);
//  - sets m_Proportions[k] to (#samples labelled k) / m_NbSamples.
402 template <
class TInputImage,
class TOutputImage>
405 unsigned int nbClasses = this->GetNumberOfClasses();
407 unsigned int componentIndex;
408 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
409 m_Proportions[componentIndex] = 0.0;
411 std::vector<typename ClassSampleType::Pointer> coeffByClass;
413 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
415 coeffByClass.push_back(ClassSampleType::New());
416 coeffByClass[componentIndex]->SetMeasurementVectorSize(m_SampleList->GetMeasurementVectorSize());
417 coeffByClass[componentIndex]->SetSample(m_SampleList);
420 typename SampleType::ConstIterator iterSample = m_SampleList->Begin();
421 typename SampleType::ConstIterator lastSample = m_SampleList->End();
423 ClassLabelVectorType::iterator iterLabel = m_ClassLabels.begin();
424 ClassLabelVectorType::iterator lastLabel = m_ClassLabels.end();
426 typename SampleType::InstanceIdentifier
id = 0;
// NOTE(review): *iterLabel is used as an index without a range check. Labels
// outside [0, nbClasses), possible with externally supplied labels, are
// undefined behaviour.
430 coeffByClass[*iterLabel]->AddInstance(
id);
431 m_Proportions[*iterLabel] += 1.0;
// do/while over samples and labels in lock step. It stops when either
// sequence ends; because of `&&` short-circuit, the label iterator is not
// advanced on the final test. If it stops early, the proportions below sum
// to less than 1, since they are still divided by m_NbSamples.
433 }
while (++iterSample != lastSample && ++iterLabel != lastLabel);
435 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
// An empty class is reported on std::cerr. The following lines are not
// shown, so whether the component update below is skipped is unknown here.
437 if (m_Proportions[componentIndex] == 0.0)
439 std::cerr <<
"No sample on class " << componentIndex;
441 std::cerr <<
" in " << ITK_LOCATION << std::endl;
445 m_ComponentVector[componentIndex]->SetSample(coeffByClass[componentIndex]);
446 m_ComponentVector[componentIndex]->Update();
449 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
450 m_Proportions[componentIndex] /=
static_cast<double>(m_NbSamples);
// PerformExpectationProcess (E-step): for each sample x and class k, computes
//   posterior(k | x) = w_k * pdf_k(x) / sum_j (w_j * pdf_j(x)),
// where w_k is the share of positions labelled k in the
// m_Neighborhood x m_Neighborhood window around the sample (a contextual
// prior). The result is stored in m_Proba[k][instance id]. This assumes
// sumPdf, computed on lines not shown, is the sum of pdf[k].
453 template <
class TInputImage,
class TOutputImage>
456 unsigned int nbClasses = this->GetNumberOfClasses();
459 int voisinage = m_Neighborhood / 2;
460 unsigned int componentIndex;
// NOTE(review): the divisor is the full window area, even at image borders
// where out-of-image positions are skipped. It is the same for every class of
// a sample, so it cancels in the normalisation by sumPdf and does not change
// the posteriors.
462 double neighborhoodWeight = (double)m_Neighborhood * m_Neighborhood;
465 typename TInputImage::SizeType size = m_Sample->GetBufferedRegion().GetSize();
469 std::vector<double> pdf(nbClasses);
470 std::vector<double> localWeight(nbClasses);
471 std::vector<double> localCount(nbClasses);
473 typename SampleType::ConstIterator iterSample = m_SampleList->Begin();
474 typename SampleType::ConstIterator lastSample = m_SampleList->End();
477 typename SampleType::InstanceIdentifier
id = 0;
481 id = iterSample.GetInstanceIdentifier();
483 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
484 localCount[componentIndex] = 0.0;
// Window of half-width voisinage around (i, j). i, j, line and cols are
// computed on lines not shown, presumably from the sample id and `size`.
// Labels are read row-major as a * cols + b. The centre position is included
// unless a hidden line skips it.
489 for (a = (i - voisinage); a <= (i + voisinage); a++)
490 for (b = (j - voisinage); b <= (j + voisinage); b++)
492 if (a < 0 || a >= line)
495 if (b < 0 || b >= cols)
498 localCount[m_ClassLabels[a * cols + b]] += 1.0;
501 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
502 localWeight[componentIndex] = localCount[componentIndex] / neighborhoodWeight;
505 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
// NOTE(review): the measurement vector does not depend on componentIndex, so
// it could be fetched once per sample, outside this loop.
507 measurementVector = iterSample.GetMeasurementVector();
508 aPdf = localWeight[componentIndex] * m_ComponentVector[componentIndex]->Pdf(measurementVector);
510 pdf[componentIndex] = aPdf;
513 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
// Degenerate case (condition not shown; presumably sumPdf == 0): every
// posterior is set to 0, and the S-step then leaves the label unchanged.
516 m_Proba[componentIndex][iterSample.GetInstanceIdentifier()] = 0.0;
518 m_Proba[componentIndex][iterSample.GetInstanceIdentifier()] = pdf[componentIndex] / sumPdf;
521 }
while (++iterSample != lastSample);
// GetMaximumAposterioriLabels: final maximum-a-posteriori decision.
// - Builds m_Output, a MembershipSample over the sample list with nbClasses
//   classes.
// - Allocates m_OutputImage on the input's buffered region.
// - For each sample, picks the class with the largest m_Proba[k][id]. Ties
//   go to the lower index because of the strict `>`, assuming `cluster`
//   starts at 0 on a hidden line.
// - Stores the chosen class both in m_Output and in the corresponding output
//   pixel.
// Output pixels are written in ImageRegionIterator order. SetSample() filled
// the list in the same order over the same region, which keeps the two
// aligned.
524 template <
class TInputImage,
class TOutputImage>
527 unsigned int nbClasses = this->GetNumberOfClasses();
530 m_Output = OutputType::New();
531 m_Output->SetSample(this->GetSampleList());
534 m_Output->SetNumberOfClasses(nbClasses);
537 m_OutputImage = TOutputImage::New();
538 m_OutputImage->SetRegions(GetSample()->GetBufferedRegion());
539 m_OutputImage->Allocate();
542 unsigned int componentIndex;
544 typename SampleType::ConstIterator sampleIter = this->GetSampleList()->Begin();
545 typename SampleType::ConstIterator sampleIterEnd = this->GetSampleList()->End();
// NOTE(review): outputIter is only advanced in the loop condition and never
// read. It merely acts as an extra loop bound.
547 typename OutputType::ConstIterator outputIter = m_Output->Begin();
548 typename OutputType::ConstIterator outputIterEnd = m_Output->End();
550 typename itk::ImageRegionIterator<TOutputImage> imgOutputIter(m_OutputImage, m_OutputImage->GetBufferedRegion());
551 imgOutputIter.GoToBegin();
552 typename itk::ImageRegionIterator<TOutputImage> imgOutputIterEnd(m_OutputImage, m_OutputImage->GetBufferedRegion());
553 imgOutputIterEnd.GoToEnd();
558 for (componentIndex = 1; componentIndex < nbClasses; ++componentIndex)
560 if (m_Proba[componentIndex][sampleIter.GetInstanceIdentifier()] > m_Proba[cluster][sampleIter.GetInstanceIdentifier()])
561 cluster = componentIndex;
564 m_Output->AddInstance(cluster, sampleIter.GetInstanceIdentifier());
565 imgOutputIter.Set(cluster);
// do/while: the three iterators advance in lock step, and the loop stops as
// soon as any one reaches its end. An empty sample list would still run the
// body once.
567 }
while (++sampleIter != sampleIterEnd && ++outputIter != outputIterEnd && ++imgOutputIter != imgOutputIterEnd);
// Update: runs the SEM loop. Each iteration performs the stochastic (S),
// expectation (E) and maximisation (M) steps. It then compares the decrease
// in the number of label changes, relative to m_NbSamples, with the
// termination threshold. The loop ends after m_MaximumIteration iterations,
// or earlier via a break on convergence (presumably on a line not shown).
// Finally the MAP labels are computed.
570 template <
class TInputImage,
class TOutputImage>
576 m_CurrentIteration = 0;
577 m_TerminationCode = NOT_CONVERGED;
584 oldNbChange = m_NbChange;
586 PerformStochasticProcess();
587 PerformExpectationProcess();
588 PerformMaximizationProcess();
// NOTE(review): if m_NbChange grows, step is negative and the ratio is below
// any positive threshold, so the run is flagged CONVERGED unless a guard
// exists on lines not shown. If oldNbChange / m_NbChange are unsigned, the
// subtraction wraps before the cast to double -- check their declared types.
590 step =
static_cast<double>(oldNbChange - m_NbChange);
// NOTE(review): CONVERGED is never reset inside the visible loop. A pass on
// an iteration that does not break (oldNbChange == 0) therefore stays
// recorded even if the loop continues to m_MaximumIteration.
593 if ((step /
static_cast<double>(m_NbSamples)) < GetTerminationThreshold())
595 m_TerminationCode = CONVERGED;
596 if (oldNbChange != 0)
// do/while: at least one iteration runs, even when m_MaximumIteration <= 0.
600 }
while (++m_CurrentIteration < m_MaximumIteration);
602 GetMaximumAposterioriLabels();
// Member / typedef index appended to this excerpt. It lists declarations
// without bodies or terminating semicolons, followed by truncated
// documentation summaries and the message-macro definitions; it is not
// compilable as written.
TOutputImage * GetOutputImage()
std::vector< double > ProportionVectorType
void SetInitialProportions(ProportionVectorType &proportions)
SampleType * GetSampleList() const
void PerformMaximizationProcess()
void SetClassLabels(OutputType *labels)
int AddComponent(int id, ComponentType *component)
SampleType::MeasurementVectorType MeasurementVectorType
void Modified() const override
itk::Statistics::ListSample< typename TInputImage::PixelType > SampleType
int GetCurrentIteration()
void SetSample(const TInputImage *sample)
void PerformStochasticProcess()
ClassLabelVectorType & GetClassLabels()
void SetNeighborhood(int neighborhood)
std::vector< ClassLabelType > ClassLabelVectorType
void PerformExpectationProcess()
void PrintSelf(std::ostream &os, itk::Indent indent) const override
itk::Statistics::MembershipSample< SampleType > OutputType
void GetMaximumAposterioriLabels()
const TInputImage * GetSample() const
is a component (derived from ModelComponentBase) for Gaussian class. This class is used in SEMClassif...
base class for distribution representation that supports analytical way to update the distribution pa...
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
// Both message macros are defined with an empty expansion here, so every
// otbMsgDebugMacro / otbMsgDevMacro call in this excerpt emits nothing.
#define otbMsgDebugMacro(x)
#define otbMsgDevMacro(x)