22 #ifndef otbSEMClassifier_hxx
23 #define otbSEMClassifier_hxx
28 #include "itkNumericTraits.h"
29 #include "itkImageIterator.h"
30 #include "itkImageRegionIterator.h"
// Default state of the classifier.
// NOTE(review): this excerpt omits the signature and braces of the definition;
// judging from the members it initialises it is presumably the SEMClassifier
// constructor — confirm against the full file.
40 template <
class TInputImage,
class TOutputImage>
// The iteration loop starts in the NOT_CONVERGED state.
43 m_TerminationCode = NOT_CONVERGED;
// No mixture component has been registered through AddComponent yet, so the
// default per-class Gaussian components will be created at initialisation.
45 m_ComponentDeclared = 0;
// The sample list is (re)built from the input image pixels when the sample is set.
48 m_SampleList =
nullptr;
// Convergence threshold: the iteration loop reports CONVERGED once
// (previous change count - current change count) / m_NbSamples falls below it.
50 m_TerminationThreshold = 1E-5;
// The label image is allocated when the maximum-a-posteriori labels are produced.
53 m_OutputImage =
nullptr;
// Prints the superclass state, then one entry per mixture component:
// "Component num <i> (prop <proportion>) " followed by whatever the component
// reports through ShowParameters(os, indent).
// No line break is written between components here; any newline must come
// from ShowParameters().
// NOTE(review): m_Proportions and m_ComponentVector are indexed up to
// GetNumberOfClasses() without checks. m_Proportions is resized to that count
// in the parameter-initialisation code (m_Proportions.resize(nbClasses)), so
// printing before that may read out of bounds — confirm.
57 template <
class TInputImage,
class TOutputImage>
60 Superclass::PrintSelf(os, indent);
62 const unsigned int nbClasses = this->GetNumberOfClasses();
64 for (
unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
66 os << indent <<
"Component num " << componentIndex;
67 os <<
" (prop " << m_Proportions[componentIndex] <<
") ";
68 m_ComponentVector[componentIndex]->ShowParameters(os, indent);
// Sets the side length of the square window of neighbouring labels that is
// counted when computing the contextual posteriors.
// The value is forced to be odd so that the window is centred on the pixel.
// Assuming integer arithmetic: an even non-negative n becomes n + 1, an odd
// non-negative n is kept, and 0 becomes 1.
// NOTE(review): because of the rounding above, the `< 1` guard below can only
// trigger for input of -2 or less; its body is not part of this excerpt.
72 template <
class TInputImage,
class TOutputImage>
75 m_Neighborhood = 2 * (neighborhood / 2) + 1;
76 if (m_Neighborhood < 1)
// Returns the contextual window size (m_Neighborhood) as stored by the setter.
80 template <
class TInputImage,
class TOutputImage>
83 return m_Neighborhood;
// Stores the prior class proportions used to draw the initial random labels.
// At initialisation they are only used when no external labels were imported
// and their count equals GetNumberOfClasses(); otherwise the initial labels
// are drawn uniformly. If they do not sum to 1 they are normalised in place,
// which modifies this stored copy.
86 template <
class TInputImage,
class TOutputImage>
89 m_InitialProportions = proportions;
// Imports an initial labelling from a membership sample.
// In each accepted branch, m_ClassLabels is resized to labels->Size() and
// filled with GetClassLabel(id) for each instance. m_ExternalLabels is then
// set, so the random label initialisation is skipped later.
// The condition of the first branch is not part of this excerpt. The second
// branch is taken when the current m_ClassLabels size equals m_NbSamples.
// Otherwise a debug message is logged and itk::ExceptionObject
// ("Vector size mismatch") is thrown.
// NOTE(review): the copy loops appear to be do/while loops (`} while`): the
// body dereferences the iterator before it is compared with labels->End(), so
// an empty labels container is not handled.
93 template <
class TInputImage,
class TOutputImage>
98 m_ClassLabels.resize(labels->Size());
100 ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
101 typename OutputType::ConstIterator iterLabels = labels->Begin();
102 typename OutputType::InstanceIdentifier
id = 0;
106 *iterClassLabel = iterLabels->GetClassLabel(
id);
110 }
while (iterLabels != labels->End());
111 m_ExternalLabels = 1;
// NOTE(review): this test uses m_ClassLabels.size() (the *previous* labelling)
// rather than labels->Size(). A labels container of a different size would
// pass and leave m_ClassLabels with a size other than m_NbSamples — confirm
// intent.
113 else if (m_ClassLabels.size() == m_NbSamples)
116 m_ClassLabels.resize(labels->Size());
118 ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
// NOTE(review): this branch uses OutputType::iterator while the first one uses
// OutputType::ConstIterator — verify that OutputType provides `iterator`.
119 typename OutputType::iterator iterLabels = labels->Begin();
120 typename OutputType::InstanceIdentifier
id = 0;
124 *iterClassLabel = iterLabels->GetClassLabel(
id);
128 }
while (iterLabels != labels->End());
129 m_ExternalLabels = 1;
133 otbMsgDebugMacro(<<
"m_ClassLabels size = " << GetClassLabels().size() <<
" / m_Sample size = " << m_NbSamples);
134 throw itk::ExceptionObject(__FILE__, __LINE__,
"Vector size mismatch", ITK_LOCATION);
// Imports an initial labelling from a label image.
// The pixels of imgLabels' buffered region are copied, in region-iteration
// order, into m_ClassLabels, and m_ExternalLabels is set.
// The copy is accepted when no sample image has been set yet
// (m_Sample == nullptr), or when the label count matches m_NbSamples.
// Otherwise a debug message is logged and itk::ExceptionObject
// ("Vector size mismatch") is thrown.
// NOTE(review): theSize uses only the first two dimensions (size[0] * size[1]).
// For an image with more than two dimensions the iterator visits more pixels
// than m_ClassLabels holds — this assumes a 2-D image.
// NOTE(review): the size is taken through TInputImage::SizeType although
// imgLabels is iterated as a TOutputImage — assumes both share a dimension.
138 template <
class TInputImage,
class TOutputImage>
143 typename TInputImage::SizeType size = imgLabels->GetBufferedRegion().GetSize();
// NOTE(review): `(int)` applies only to size[0]. The product is computed in
// the element type of SizeType (presumably unsigned) and then narrowed to int.
144 int theSize = (int)size[0] * size[1];
146 if (m_Sample ==
nullptr)
148 m_ClassLabels.resize(theSize);
150 typename itk::ImageRegionIterator<TOutputImage> imgLabelIter(imgLabels, imgLabels->GetBufferedRegion());
151 imgLabelIter.GoToBegin();
152 typename itk::ImageRegionIterator<TOutputImage> imgLabelIterEnd(imgLabels, imgLabels->GetBufferedRegion());
153 imgLabelIterEnd.GoToEnd();
155 ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
// The copy loop tests the iterator after the body (`} while`), so an empty
// region is not handled.
159 *iterClassLabel = imgLabelIter.Get();
162 }
while (imgLabelIter != imgLabelIterEnd);
163 m_ExternalLabels = 1;
165 else if (theSize == m_NbSamples)
167 m_ClassLabels.resize(theSize);
169 typename itk::ImageRegionIterator<TOutputImage> imgLabelIter(imgLabels, imgLabels->GetBufferedRegion());
170 imgLabelIter.GoToBegin();
171 typename itk::ImageRegionIterator<TOutputImage> imgLabelIterEnd(imgLabels, imgLabels->GetBufferedRegion());
172 imgLabelIterEnd.GoToEnd();
174 ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
178 *iterClassLabel = imgLabelIter.Get();
181 }
while (imgLabelIter != imgLabelIterEnd);
182 m_ExternalLabels = 1;
186 otbMsgDebugMacro(<<
"m_ClassLabels size = " << GetClassLabels().size() <<
" size of the image = " << theSize <<
" / m_Sample size = " << m_NbSamples);
187 throw itk::ExceptionObject(__FILE__, __LINE__,
"Vector size mismatch", ITK_LOCATION);
// Returns the current per-sample class labels (m_ClassLabels).
// These are either imported through SetClassLabels or drawn at
// initialisation, and are then updated by the stochastic step.
191 template <
class TInputImage,
class TOutputImage>
194 return m_ClassLabels;
// Builds m_SampleList from the input image m_Sample: one measurement vector
// (of m_Sample->GetVectorLength() components) per pixel of the buffered
// region, pushed in region-iteration order.
// If labels were imported beforehand (m_ExternalLabels), it checks that the
// image has as many pixels as there are labels (first two dimensions only),
// and throws itk::ExceptionObject("Vector size mismatch") otherwise.
197 template <
class TInputImage,
class TOutputImage>
202 m_SampleList = SampleType::New();
203 m_SampleList->SetMeasurementVectorSize(m_Sample->GetVectorLength());
// NOTE(review): the C-style casts `(TInputImage*)m_Sample` presumably drop
// constness to build a non-const ImageRegionIterator; a const iterator would
// avoid the cast — confirm m_Sample's declared type.
205 itk::ImageRegionIterator<TInputImage> imgIter((TInputImage*)m_Sample, m_Sample->GetBufferedRegion());
207 itk::ImageRegionIterator<TInputImage> imgIterEnd((TInputImage*)m_Sample, m_Sample->GetBufferedRegion());
208 imgIterEnd.GoToEnd();
// NOTE(review): the loop pushes a pixel before comparing with the end iterator
// (`} while`), so an empty region is not handled.
212 m_SampleList->PushBack(imgIter.Get());
215 }
while (imgIter != imgIterEnd);
217 if (m_ExternalLabels)
219 typename TInputImage::SizeType size = m_Sample->GetBufferedRegion().GetSize();
220 if ((size[0] * size[1]) != m_ClassLabels.size())
221 throw itk::ExceptionObject(__FILE__, __LINE__,
"Vector size mismatch", ITK_LOCATION);
// NOTE(review): only the template header of this member definition is part of
// this excerpt; its signature and body are not visible here.
225 template <
class TInputImage,
class TOutputImage>
// NOTE(review): only the template header of this member definition is part of
// this excerpt; its signature and body are not visible here.
231 template <
class TInputImage,
class TOutputImage>
// Returns the index of the current SEM iteration (m_CurrentIteration).
// It is reset to 0 when the iteration loop starts.
237 template <
class TInputImage,
class TOutputImage>
240 return m_CurrentIteration;
// Registers `component` as the mixture component of class `id` by storing it
// at m_ComponentVector[id].
// It also records that components were declared, which suppresses the default
// per-class Gaussian initialisation.
// Returns the number of component slots (m_ComponentVector.size()), not the
// number of components actually registered.
// NOTE(review): `id` is not bounds-checked. m_ComponentVector is resized to
// GetNumberOfClasses() when it is reset, so an id outside that range writes
// out of bounds.
243 template <
class TInputImage,
class TOutputImage>
246 m_ComponentVector[id] = component;
247 m_ComponentDeclared = 1;
249 return static_cast<int>(m_ComponentVector.size());
// Forwards to Superclass::Modified(), then discards any declared components.
// m_ComponentVector is reallocated with one empty slot per class
// (GetNumberOfClasses()), and m_ComponentDeclared is cleared.
// NOTE(review): if this overrides Modified(), every later modification of the
// object drops components previously registered through AddComponent —
// confirm this is intended.
252 template <
class TInputImage,
class TOutputImage>
255 Superclass::Modified();
256 if (m_ComponentDeclared == 1)
257 otbMsgDebugMacro(<<
"Previous component declarations will be lost since called before SetNumberOfClasses");
// Only the debug message above is conditional; the reset below always runs.
258 m_ComponentVector.clear();
259 m_ComponentVector.resize(this->GetNumberOfClasses());
260 m_ComponentDeclared = 0;
// Returns the label image holding the maximum-a-posteriori class of each pixel.
// It is set to nullptr at construction and allocated when the MAP labels are
// computed.
263 template <
class TInputImage,
class TOutputImage>
266 return m_OutputImage;
// NOTE(review): only the template header of this member definition is part of
// this excerpt; its signature and body are not visible here.
269 template <
class TInputImage,
class TOutputImage>
// Initialises the SEM state before iterating:
//  1. Unless labels were imported (m_ExternalLabels), it draws an initial
//     random label for each of the m_NbSamples samples. The draw is uniform
//     when the number of initial proportions differs from the number of
//     classes; otherwise it follows m_InitialProportions, normalised in place
//     to sum to 1.
//  2. It sizes m_Proportions (one entry per class) and m_Proba (one vector of
//     m_NbSamples posteriors per class).
//  3. If no component was declared through AddComponent, it creates one
//     default GaussianType component per class.
// NOTE(review): labels are drawn with the C library rand(), so results depend
// on the global srand() state; no seeding is visible in this excerpt.
275 template <
class TInputImage,
class TOutputImage>
278 unsigned int nbClasses = this->GetNumberOfClasses();
280 if (!m_ExternalLabels)
282 m_ClassLabels.resize(m_NbSamples);
// Uniform initial labels when the priors do not match the class count.
// NOTE(review): rand() % nbClasses is already < nbClasses, so the clamp below
// is unreachable; the modulo divides by zero if nbClasses == 0.
283 if (
static_cast<unsigned int>(m_InitialProportions.size()) != nbClasses)
286 for (
typename ClassLabelVectorType::iterator labelIter = m_ClassLabels.begin(); labelIter != m_ClassLabels.end(); ++labelIter)
289 label = rand() % nbClasses;
290 if (label >= nbClasses)
292 label = nbClasses - 1;
// Normalise the prior proportions so they sum to 1 (modifies
// m_InitialProportions in place).
// NOTE(review): a zero sum would make every proportion non-finite.
300 double sumProportion = 0.0;
301 typename ProportionVectorType::iterator iterProportion = m_InitialProportions.begin();
304 sumProportion += *iterProportion;
305 }
while (++iterProportion != m_InitialProportions.end());
307 if (sumProportion != 1.0)
309 for (iterProportion = m_InitialProportions.begin(); iterProportion != m_InitialProportions.end(); ++iterProportion)
310 *iterProportion /= sumProportion;
// Inverse-CDF draw: `sample` is uniform in [0, 1), and the label becomes the
// class whose cumulative-proportion interval contains it.
// The nbClasses - 1 default presumably covers a cumulative sum that rounding
// leaves slightly below 1.
315 double cumulativeProportion;
316 for (
typename ClassLabelVectorType::iterator labelIter = m_ClassLabels.begin(); labelIter != m_ClassLabels.end(); ++labelIter)
318 cumulativeProportion = 0.0;
319 sample = double(rand()) / (double(RAND_MAX) + 1.0);
321 *labelIter = nbClasses - 1;
322 for (
unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
324 if (cumulativeProportion <= sample && sample < cumulativeProportion + m_InitialProportions[componentIndex])
326 *labelIter = componentIndex;
330 cumulativeProportion += m_InitialProportions[componentIndex];
// One proportion and one posterior vector (m_NbSamples entries) per class.
336 m_Proportions.resize(nbClasses);
337 m_Proba.resize(nbClasses);
338 for (
unsigned int i = 0; i < nbClasses; ++i)
339 m_Proba[i].resize(m_NbSamples);
// Default mixture: one Gaussian component per class when none was declared.
341 if (!m_ComponentDeclared)
343 otbMsgDebugMacro(<<
"default mixture initialization with " << nbClasses <<
" Gaussian components");
346 for (
unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
348 typename GaussianType::Pointer comp = GaussianType::New();
349 AddComponent(componentIndex, comp);
356 otbMsgDevMacro(<<
"contextual neighborhood : " << m_Neighborhood);
357 otbMsgDevMacro(<<
"terminationThreshold : " << m_TerminationThreshold);
// Stochastic (S) step: for every sample, it draws a new label from the
// sample's current posterior distribution m_Proba[.][sample].
// A uniform x in [0, 1) is drawn, and the chosen class is the one whose
// cumulative-posterior interval (y, z] contains x. Labels that differ from
// the previous one are presumably counted in m_NbChange, which is logged at
// the end.
// NOTE(review): the initialisation of y, z and posSample is not part of this
// excerpt. If the posteriors of a sample sum to less than x, no interval
// matches and the label is left unchanged.
// NOTE(review): the cases of the switch on GetCurrentIteration() are not
// visible here.
360 template <
class TInputImage,
class TOutputImage>
363 unsigned int nbClasses = this->GetNumberOfClasses();
369 for (
typename ClassLabelVectorType::iterator iter = m_ClassLabels.begin(); iter != m_ClassLabels.end(); ++iter)
371 x = double(rand()) / (double(RAND_MAX) + 1.0);
374 for (
unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
377 z += m_Proba[componentIndex][posSample];
379 if ((y < x) && (x <= z))
381 if (
static_cast<unsigned int>(componentIndex) != *iter)
384 *iter = componentIndex;
391 switch (GetCurrentIteration())
398 otbMsgDebugMacro(<< m_NbChange <<
" sample change at iteration " << GetCurrentIteration());
// Maximisation (M) step: re-estimates each mixture component and the class
// proportions from the current labels.
//  - It builds one ClassSampleType subsample per class over m_SampleList.
//  - It walks samples and labels in lockstep. Each sample instance is added to
//    the subsample of its label and counted in m_Proportions.
//  - For each class, it hands the subsample to the component and calls
//    Update() on it. An empty class is reported on std::cerr; whether its
//    re-estimation is skipped depends on lines not in this excerpt.
//  - It divides the counts by m_NbSamples to obtain proportions.
// NOTE(review): *iterLabel indexes coeffByClass and m_Proportions without
// checks, so an imported label >= GetNumberOfClasses() is out of bounds.
// NOTE(review): the loop handles the first sample before comparing with the end
// iterators (`} while`), so an empty sample list is not handled.
402 template <
class TInputImage,
class TOutputImage>
405 unsigned int nbClasses = this->GetNumberOfClasses();
407 unsigned int componentIndex;
408 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
409 m_Proportions[componentIndex] = 0.0;
// One subsample per class, each referencing the full sample list.
411 std::vector<typename ClassSampleType::Pointer> coeffByClass;
413 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
415 coeffByClass.push_back(ClassSampleType::New());
416 coeffByClass[componentIndex]->SetMeasurementVectorSize(m_SampleList->GetMeasurementVectorSize());
417 coeffByClass[componentIndex]->SetSample(m_SampleList);
420 typename SampleType::ConstIterator iterSample = m_SampleList->Begin();
421 typename SampleType::ConstIterator lastSample = m_SampleList->End();
423 ClassLabelVectorType::iterator iterLabel = m_ClassLabels.begin();
424 ClassLabelVectorType::iterator lastLabel = m_ClassLabels.end();
426 typename SampleType::InstanceIdentifier
id = 0;
430 coeffByClass[*iterLabel]->AddInstance(
id);
431 m_Proportions[*iterLabel] += 1.0;
433 }
// Stops when either the samples or the labels run out. Because of `&&`
// short-circuiting, iterLabel is not advanced on the final step.
while (++iterSample != lastSample && ++iterLabel != lastLabel);
435 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
437 if (m_Proportions[componentIndex] == 0.0)
439 std::cerr <<
"No sample on class " << componentIndex;
441 std::cerr <<
" in " << ITK_LOCATION << std::endl;
445 m_ComponentVector[componentIndex]->SetSample(coeffByClass[componentIndex]);
446 m_ComponentVector[componentIndex]->Update();
// Counts -> proportions.
// NOTE(review): m_NbSamples == 0 would produce non-finite values.
449 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
450 m_Proportions[componentIndex] /=
static_cast<double>(m_NbSamples);
// Expectation (E) step with spatial context: for every sample, it computes
// the posterior probability of each class and stores it in
// m_Proba[class][instance id].
//  - It counts the current labels inside the m_Neighborhood x m_Neighborhood
//    window around the sample. Positions outside [0, line) x [0, cols) are
//    handled by the guards, presumably skipped. The counts are divided by the
//    full window area to give local weights.
//  - It weights each component's Pdf() of the sample's measurement vector by
//    that local weight.
//  - It normalises by sumPdf. A branch whose condition is not in this excerpt
//    (presumably sumPdf == 0) writes 0.0 instead.
// The window-area divisor is the same for every class, so it cancels in the
// normalisation even at image borders where fewer pixels are counted.
// NOTE(review): i, j, line, cols, a, b and sumPdf are computed in lines not
// part of this excerpt. Presumably (i, j) is the sample's row/column and
// line/cols the image height/width (2-D), matching the row-major index
// a * cols + b into m_ClassLabels.
453 template <
class TInputImage,
class TOutputImage>
456 unsigned int nbClasses = this->GetNumberOfClasses();
// voisinage (French for "neighbourhood"): half-width of the window.
459 int voisinage = m_Neighborhood / 2;
460 unsigned int componentIndex;
462 double neighborhoodWeight = (double)m_Neighborhood * m_Neighborhood;
465 typename TInputImage::SizeType size = m_Sample->GetBufferedRegion().GetSize();
469 std::vector<double> pdf(nbClasses);
470 std::vector<double> localWeight(nbClasses);
471 std::vector<double> localCount(nbClasses);
473 typename SampleType::ConstIterator iterSample = m_SampleList->Begin();
474 typename SampleType::ConstIterator lastSample = m_SampleList->End();
477 typename SampleType::InstanceIdentifier
id = 0;
481 id = iterSample.GetInstanceIdentifier();
483 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
484 localCount[componentIndex] = 0.0;
489 for (a = (i - voisinage); a <= (i + voisinage); a++)
490 for (b = (j - voisinage); b <= (j + voisinage); b++)
492 if (a < 0 || a >= line)
495 if (b < 0 || b >= cols)
// NOTE(review): the label is used as an index without a range check.
498 localCount[m_ClassLabels[a * cols + b]] += 1.0;
501 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
502 localWeight[componentIndex] = localCount[componentIndex] / neighborhoodWeight;
505 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
// NOTE(review): the measurement vector does not depend on componentIndex and
// could be fetched once per sample.
507 measurementVector = iterSample.GetMeasurementVector();
508 aPdf = localWeight[componentIndex] * m_ComponentVector[componentIndex]->Pdf(measurementVector);
510 pdf[componentIndex] = aPdf;
513 for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
516 m_Proba[componentIndex][iterSample.GetInstanceIdentifier()] = 0.0;
518 m_Proba[componentIndex][iterSample.GetInstanceIdentifier()] = pdf[componentIndex] / sumPdf;
521 }
while (++iterSample != lastSample);
// Produces the final classification by maximum a posteriori (MAP). It creates:
//  - m_Output: a new membership sample over the sample list with nbClasses
//    classes;
//  - m_OutputImage: a new label image allocated on the input's buffered region.
// Each sample is assigned the class with the largest m_Proba entry. Only a
// strictly greater value replaces the current best, so ties keep the lower
// class index, assuming `cluster` starts at 0 (its initialisation is not
// shown). The winner is added to m_Output and written to the pixel at the
// same position in region-iteration order.
// NOTE(review): the loop advances three iterators in one `&&` chain and stops
// as soon as the first one reaches its end. If it stops early, the remaining
// output pixels keep whatever Allocate() left in them (no fill is visible).
524 template <
class TInputImage,
class TOutputImage>
527 unsigned int nbClasses = this->GetNumberOfClasses();
530 m_Output = OutputType::New();
531 m_Output->SetSample(this->GetSampleList());
534 m_Output->SetNumberOfClasses(nbClasses);
537 m_OutputImage = TOutputImage::New();
538 m_OutputImage->SetRegions(GetSample()->GetBufferedRegion());
539 m_OutputImage->Allocate();
542 unsigned int componentIndex;
544 typename SampleType::ConstIterator sampleIter = this->GetSampleList()->Begin();
545 typename SampleType::ConstIterator sampleIterEnd = this->GetSampleList()->End();
// outputIter is only used as an additional stop condition in the loop below.
547 typename OutputType::ConstIterator outputIter = m_Output->Begin();
548 typename OutputType::ConstIterator outputIterEnd = m_Output->End();
550 typename itk::ImageRegionIterator<TOutputImage> imgOutputIter(m_OutputImage, m_OutputImage->GetBufferedRegion());
551 imgOutputIter.GoToBegin();
552 typename itk::ImageRegionIterator<TOutputImage> imgOutputIterEnd(m_OutputImage, m_OutputImage->GetBufferedRegion());
553 imgOutputIterEnd.GoToEnd();
558 for (componentIndex = 1; componentIndex < nbClasses; ++componentIndex)
560 if (m_Proba[componentIndex][sampleIter.GetInstanceIdentifier()] > m_Proba[cluster][sampleIter.GetInstanceIdentifier()])
561 cluster = componentIndex;
564 m_Output->AddInstance(cluster, sampleIter.GetInstanceIdentifier());
565 imgOutputIter.Set(cluster);
567 }
while (++sampleIter != sampleIterEnd && ++outputIter != outputIterEnd && ++imgOutputIter != imgOutputIterEnd);
// Drives the SEM iterations, presumably as the public entry point.
// It starts from iteration 0 in the NOT_CONVERGED state. Each iteration:
//  - runs the stochastic, expectation and maximisation steps;
//  - compares the drop in the number of label changes, relative to
//    m_NbSamples, with GetTerminationThreshold(), and sets
//    m_TerminationCode = CONVERGED when it is below.
// The loop is a do/while, so at least one iteration runs. It continues until
// ++m_CurrentIteration reaches m_MaximumIteration, unless it exits early
// through lines not in this excerpt. Afterwards the MAP labels are produced.
// NOTE(review): lines between the state reset and the loop (original lines
// 578-583) are not part of this excerpt (presumably parameter initialisation).
570 template <
class TInputImage,
class TOutputImage>
576 m_CurrentIteration = 0;
577 m_TerminationCode = NOT_CONVERGED;
584 oldNbChange = m_NbChange;
586 PerformStochasticProcess();
587 PerformExpectationProcess();
588 PerformMaximizationProcess();
// NOTE(review): the subtraction happens before the cast to double. If the
// counters are unsigned, an increase in changes wraps to a huge step; if they
// are signed, any increase yields a negative step that passes the
// `< threshold` test. Their types are not visible here — confirm.
590 step =
static_cast<double>(oldNbChange - m_NbChange);
// NOTE(review): m_NbSamples == 0 would make this ratio non-finite.
593 if ((step /
static_cast<double>(m_NbSamples)) < GetTerminationThreshold())
595 m_TerminationCode = CONVERGED;
596 if (oldNbChange != 0)
600 }
while (++m_CurrentIteration < m_MaximumIteration);
602 GetMaximumAposterioriLabels();