OTB  10.0.0
Orfeo Toolbox
otbSEMClassifier.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  * Copyright (C) 2007-2012 Institut Mines Telecom / Telecom Bretagne
4  *
5  * This file is part of Orfeo Toolbox
6  *
7  * https://www.orfeo-toolbox.org/
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  * http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  */
21 
22 #ifndef otbSEMClassifier_hxx
23 #define otbSEMClassifier_hxx
24 
25 #include <cstdlib>
26 
27 #include <iostream>
28 #include "itkNumericTraits.h"
29 #include "itkImageIterator.h"
30 #include "itkImageRegionIterator.h"
31 
32 #include "otbMacro.h"
33 // default mixture model
34 
35 #include "otbSEMClassifier.h"
36 
37 namespace otb
38 {
39 
40 template <class TInputImage, class TOutputImage>
42 {
43  m_TerminationCode = NOT_CONVERGED;
44  m_ExternalLabels = 0;
45  m_ComponentDeclared = 0;
46  m_Sample = nullptr;
47  m_NbSamples = 0;
48  m_SampleList = nullptr;
49  m_NbChange = 0;
50  m_TerminationThreshold = 1E-5;
51  m_Neighborhood = 1;
52 
53  m_OutputImage = nullptr;
54  m_Output = nullptr;
55 }
56 
57 template <class TInputImage, class TOutputImage>
58 void SEMClassifier<TInputImage, TOutputImage>::PrintSelf(std::ostream& os, itk::Indent indent) const
59 {
60  Superclass::PrintSelf(os, indent);
61 
62  const unsigned int nbClasses = this->GetNumberOfClasses();
63 
64  for (unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
65  {
66  os << indent << "Component num " << componentIndex;
67  os << " (prop " << m_Proportions[componentIndex] << ") ";
68  m_ComponentVector[componentIndex]->ShowParameters(os, indent);
69  }
70 }
71 
72 template <class TInputImage, class TOutputImage>
74 {
75  m_Neighborhood = 2 * (neighborhood / 2) + 1;
76  if (m_Neighborhood < 1)
77  m_Neighborhood = 1;
78 }
79 
80 template <class TInputImage, class TOutputImage>
82 {
83  return m_Neighborhood;
84 }
85 
86 template <class TInputImage, class TOutputImage>
88 {
89  m_InitialProportions = proportions;
90  m_ExternalLabels = 0;
91 }
92 
93 template <class TInputImage, class TOutputImage>
95 {
96  if (m_Sample == NULL)
97  {
98  m_ClassLabels.resize(labels->Size());
99 
100  ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
101  typename OutputType::ConstIterator iterLabels = labels->Begin();
102  typename OutputType::InstanceIdentifier id = 0;
103 
104  do
105  {
106  *iterClassLabel = iterLabels->GetClassLabel(id);
107  ++iterLabels;
108  ++iterClassLabel;
109  id++;
110  } while (iterLabels != labels->End());
111  m_ExternalLabels = 1;
112  }
113  else if (m_ClassLabels.size() == m_NbSamples) // FIXME check if this
114  // is really the right condition
115  {
116  m_ClassLabels.resize(labels->Size());
117 
118  ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
119  typename OutputType::iterator iterLabels = labels->Begin();
120  typename OutputType::InstanceIdentifier id = 0;
121 
122  do
123  {
124  *iterClassLabel = iterLabels->GetClassLabel(id);
125  ++iterLabels;
126  ++iterClassLabel;
127  id++;
128  } while (iterLabels != labels->End());
129  m_ExternalLabels = 1;
130  }
131  else
132  {
133  otbMsgDebugMacro(<< "m_ClassLabels size = " << GetClassLabels().size() << " / m_Sample size = " << m_NbSamples);
134  throw itk::ExceptionObject(__FILE__, __LINE__, "Vector size mismatch", ITK_LOCATION);
135  }
136 }
137 
138 template <class TInputImage, class TOutputImage>
140 {
141  otbMsgDebugMacro(<< "Initializing segmentation from an external image");
142 
143  typename TInputImage::SizeType size = imgLabels->GetBufferedRegion().GetSize();
144  int theSize = (int)size[0] * size[1];
145 
146  if (m_Sample == nullptr)
147  {
148  m_ClassLabels.resize(theSize);
149 
150  typename itk::ImageRegionIterator<TOutputImage> imgLabelIter(imgLabels, imgLabels->GetBufferedRegion());
151  imgLabelIter.GoToBegin();
152  typename itk::ImageRegionIterator<TOutputImage> imgLabelIterEnd(imgLabels, imgLabels->GetBufferedRegion());
153  imgLabelIterEnd.GoToEnd();
154 
155  ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
156 
157  do
158  {
159  *iterClassLabel = imgLabelIter.Get();
160  ++imgLabelIter;
161  ++iterClassLabel;
162  } while (imgLabelIter != imgLabelIterEnd);
163  m_ExternalLabels = 1;
164  }
165  else if (theSize == m_NbSamples)
166  {
167  m_ClassLabels.resize(theSize);
168 
169  typename itk::ImageRegionIterator<TOutputImage> imgLabelIter(imgLabels, imgLabels->GetBufferedRegion());
170  imgLabelIter.GoToBegin();
171  typename itk::ImageRegionIterator<TOutputImage> imgLabelIterEnd(imgLabels, imgLabels->GetBufferedRegion());
172  imgLabelIterEnd.GoToEnd();
173 
174  ClassLabelVectorType::iterator iterClassLabel = m_ClassLabels.begin();
175 
176  do
177  {
178  *iterClassLabel = imgLabelIter.Get();
179  ++imgLabelIter;
180  ++iterClassLabel;
181  } while (imgLabelIter != imgLabelIterEnd);
182  m_ExternalLabels = 1;
183  }
184  else
185  {
186  otbMsgDebugMacro(<< "m_ClassLabels size = " << GetClassLabels().size() << " size of the image = " << theSize << " / m_Sample size = " << m_NbSamples);
187  throw itk::ExceptionObject(__FILE__, __LINE__, "Vector size mismatch", ITK_LOCATION);
188  }
189 }
190 
191 template <class TInputImage, class TOutputImage>
193 {
194  return m_ClassLabels;
195 }
196 
197 template <class TInputImage, class TOutputImage>
199 {
200  m_Sample = sample;
201  m_NbSamples = 0;
202  m_SampleList = SampleType::New();
203  m_SampleList->SetMeasurementVectorSize(m_Sample->GetVectorLength());
204 
205  itk::ImageRegionIterator<TInputImage> imgIter((TInputImage*)m_Sample, m_Sample->GetBufferedRegion());
206  imgIter.GoToBegin();
207  itk::ImageRegionIterator<TInputImage> imgIterEnd((TInputImage*)m_Sample, m_Sample->GetBufferedRegion());
208  imgIterEnd.GoToEnd();
209 
210  do
211  {
212  m_SampleList->PushBack(imgIter.Get());
213  ++m_NbSamples;
214  ++imgIter;
215  } while (imgIter != imgIterEnd);
216 
217  if (m_ExternalLabels)
218  {
219  typename TInputImage::SizeType size = m_Sample->GetBufferedRegion().GetSize();
220  if ((size[0] * size[1]) != m_ClassLabels.size())
221  throw itk::ExceptionObject(__FILE__, __LINE__, "Vector size mismatch", ITK_LOCATION);
222  }
223 }
224 
225 template <class TInputImage, class TOutputImage>
227 {
228  return m_Sample;
229 }
230 
231 template <class TInputImage, class TOutputImage>
233 {
234  return m_SampleList;
235 }
236 
237 template <class TInputImage, class TOutputImage>
239 {
240  return m_CurrentIteration;
241 }
242 
243 template <class TInputImage, class TOutputImage>
245 {
246  m_ComponentVector[id] = component;
247  m_ComponentDeclared = 1;
248 
249  return static_cast<int>(m_ComponentVector.size());
250 }
251 
252 template <class TInputImage, class TOutputImage>
254 {
255  Superclass::Modified();
256  if (m_ComponentDeclared == 1)
257  otbMsgDebugMacro(<< "Previous component declarations will be lost since called before SetNumberOfClasses");
258  m_ComponentVector.clear();
259  m_ComponentVector.resize(this->GetNumberOfClasses());
260  m_ComponentDeclared = 0;
261 }
262 
263 template <class TInputImage, class TOutputImage>
265 {
266  return m_OutputImage;
267 }
268 
269 template <class TInputImage, class TOutputImage>
271 {
272  return m_Output;
273 }
274 
275 template <class TInputImage, class TOutputImage>
277 {
278  unsigned int nbClasses = this->GetNumberOfClasses();
279 
280  if (!m_ExternalLabels)
281  {
282  m_ClassLabels.resize(m_NbSamples);
283  if (static_cast<unsigned int>(m_InitialProportions.size()) != nbClasses)
284  {
285  unsigned int label;
286  for (typename ClassLabelVectorType::iterator labelIter = m_ClassLabels.begin(); labelIter != m_ClassLabels.end(); ++labelIter)
287  {
288  // label = (int) floor( 0.5 + nbClassesDbl * ran / double(RAND_MAX+1) );
289  label = rand() % nbClasses;
290  if (label >= nbClasses)
291  {
292  label = nbClasses - 1;
293  }
294  *labelIter = label;
295  }
296  }
297  else
298  {
299  // Be sure, the sum of initial proportion remains to 1
300  double sumProportion = 0.0;
301  typename ProportionVectorType::iterator iterProportion = m_InitialProportions.begin();
302  do
303  {
304  sumProportion += *iterProportion;
305  } while (++iterProportion != m_InitialProportions.end());
306 
307  if (sumProportion != 1.0)
308  {
309  for (iterProportion = m_InitialProportions.begin(); iterProportion != m_InitialProportions.end(); ++iterProportion)
310  *iterProportion /= sumProportion;
311  }
312 
313  // non uniform random sampling according to m_InitialProportions
314  double sample;
315  double cumulativeProportion;
316  for (typename ClassLabelVectorType::iterator labelIter = m_ClassLabels.begin(); labelIter != m_ClassLabels.end(); ++labelIter)
317  {
318  cumulativeProportion = 0.0;
319  sample = double(rand()) / (double(RAND_MAX) + 1.0);
320 
321  *labelIter = nbClasses - 1;
322  for (unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
323  {
324  if (cumulativeProportion <= sample && sample < cumulativeProportion + m_InitialProportions[componentIndex])
325  {
326  *labelIter = componentIndex;
327  break;
328  }
329  else
330  cumulativeProportion += m_InitialProportions[componentIndex];
331  }
332  }
333  }
334  }
335 
336  m_Proportions.resize(nbClasses);
337  m_Proba.resize(nbClasses);
338  for (unsigned int i = 0; i < nbClasses; ++i)
339  m_Proba[i].resize(m_NbSamples);
340 
341  if (!m_ComponentDeclared)
342  {
343  otbMsgDebugMacro(<< "default mixture initialization with " << nbClasses << " Gaussian components");
345 
346  for (unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
347  {
348  typename GaussianType::Pointer comp = GaussianType::New();
349  AddComponent(componentIndex, comp);
350  }
351  }
352 
353  otbMsgDevMacro(<< "num class : " << nbClasses);
354  otbMsgDevMacro(<< "num sample : " << GetSampleList()->Size());
355  otbMsgDevMacro(<< "num labels : " << GetClassLabels().size());
356  otbMsgDevMacro(<< "contextual neighborhood : " << m_Neighborhood);
357  otbMsgDevMacro(<< "terminationThreshold : " << m_TerminationThreshold);
358 }
359 
360 template <class TInputImage, class TOutputImage>
362 {
363  unsigned int nbClasses = this->GetNumberOfClasses();
364 
365  double x, y, z;
366  m_NbChange = 0;
367 
368  int posSample = 0;
369  for (typename ClassLabelVectorType::iterator iter = m_ClassLabels.begin(); iter != m_ClassLabels.end(); ++iter)
370  {
371  x = double(rand()) / (double(RAND_MAX) + 1.0);
372  z = 0.0;
373 
374  for (unsigned int componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
375  {
376  y = z;
377  z += m_Proba[componentIndex][posSample];
378 
379  if ((y < x) && (x <= z))
380  {
381  if (static_cast<unsigned int>(componentIndex) != *iter)
382  m_NbChange++;
383 
384  *iter = componentIndex;
385  break;
386  }
387  }
388  posSample++;
389  }
390 
391  switch (GetCurrentIteration())
392  {
393  case 0:
394  case 1:
395  otbMsgDebugMacro(<< "Doing iteration " << GetCurrentIteration());
396  break;
397  default:
398  otbMsgDebugMacro(<< m_NbChange << " sample change at iteration " << GetCurrentIteration());
399  }
400 }
401 
402 template <class TInputImage, class TOutputImage>
404 {
405  unsigned int nbClasses = this->GetNumberOfClasses();
406 
407  unsigned int componentIndex;
408  for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
409  m_Proportions[componentIndex] = 0.0;
410 
411  std::vector<typename ClassSampleType::Pointer> coeffByClass;
412 
413  for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
414  {
415  coeffByClass.push_back(ClassSampleType::New());
416  coeffByClass[componentIndex]->SetMeasurementVectorSize(m_SampleList->GetMeasurementVectorSize());
417  coeffByClass[componentIndex]->SetSample(m_SampleList);
418  }
419 
420  typename SampleType::ConstIterator iterSample = m_SampleList->Begin();
421  typename SampleType::ConstIterator lastSample = m_SampleList->End();
422 
423  ClassLabelVectorType::iterator iterLabel = m_ClassLabels.begin();
424  ClassLabelVectorType::iterator lastLabel = m_ClassLabels.end();
425 
426  typename SampleType::InstanceIdentifier id = 0;
427 
428  do
429  {
430  coeffByClass[*iterLabel]->AddInstance(id);
431  m_Proportions[*iterLabel] += 1.0;
432  id++;
433  } while (++iterSample != lastSample && ++iterLabel != lastLabel);
434 
435  for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
436  {
437  if (m_Proportions[componentIndex] == 0.0)
438  {
439  std::cerr << "No sample on class " << componentIndex;
440  // std::cerr << " in " << __PRETTY_FUNCTION__ << std::endl;
441  std::cerr << " in " << ITK_LOCATION << std::endl;
442  continue;
443  }
444 
445  m_ComponentVector[componentIndex]->SetSample(coeffByClass[componentIndex]);
446  m_ComponentVector[componentIndex]->Update();
447  }
448 
449  for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
450  m_Proportions[componentIndex] /= static_cast<double>(m_NbSamples);
451 }
452 
453 template <class TInputImage, class TOutputImage>
455 {
456  unsigned int nbClasses = this->GetNumberOfClasses();
457 
458  int i, j, a, b;
459  int voisinage = m_Neighborhood / 2;
460  unsigned int componentIndex;
461  double sumPdf, aPdf;
462  double neighborhoodWeight = (double)m_Neighborhood * m_Neighborhood;
463 
464  int line, cols;
465  typename TInputImage::SizeType size = m_Sample->GetBufferedRegion().GetSize();
466  cols = (int)size[0];
467  line = (int)size[1];
468 
469  std::vector<double> pdf(nbClasses);
470  std::vector<double> localWeight(nbClasses);
471  std::vector<double> localCount(nbClasses);
472 
473  typename SampleType::ConstIterator iterSample = m_SampleList->Begin();
474  typename SampleType::ConstIterator lastSample = m_SampleList->End();
475  MeasurementVectorType measurementVector;
476 
477  typename SampleType::InstanceIdentifier id = 0;
478 
479  do
480  {
481  id = iterSample.GetInstanceIdentifier();
482 
483  for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
484  localCount[componentIndex] = 0.0;
485 
486  i = id / cols;
487  j = id % cols;
488 
489  for (a = (i - voisinage); a <= (i + voisinage); a++)
490  for (b = (j - voisinage); b <= (j + voisinage); b++)
491  {
492  if (a < 0 || a >= line)
493  continue;
494 
495  if (b < 0 || b >= cols)
496  continue;
497 
498  localCount[m_ClassLabels[a * cols + b]] += 1.0;
499  }
500 
501  for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
502  localWeight[componentIndex] = localCount[componentIndex] / neighborhoodWeight;
503 
504  sumPdf = 0.0;
505  for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
506  {
507  measurementVector = iterSample.GetMeasurementVector();
508  aPdf = localWeight[componentIndex] * m_ComponentVector[componentIndex]->Pdf(measurementVector);
509  sumPdf += aPdf;
510  pdf[componentIndex] = aPdf;
511  }
512 
513  for (componentIndex = 0; componentIndex < nbClasses; ++componentIndex)
514  {
515  if (sumPdf == 0.0)
516  m_Proba[componentIndex][iterSample.GetInstanceIdentifier()] = 0.0;
517  else
518  m_Proba[componentIndex][iterSample.GetInstanceIdentifier()] = pdf[componentIndex] / sumPdf;
519  }
520 
521  } while (++iterSample != lastSample);
522 }
523 
524 template <class TInputImage, class TOutputImage>
526 {
527  unsigned int nbClasses = this->GetNumberOfClasses();
528 
529  // Class results initialization
530  m_Output = OutputType::New();
531  m_Output->SetSample(this->GetSampleList());
532  // m_Output->Resize(this->GetSampleList()->Size()); //FIXME check if
533  // still necessary
534  m_Output->SetNumberOfClasses(nbClasses);
535 
536  // Image results classification
537  m_OutputImage = TOutputImage::New();
538  m_OutputImage->SetRegions(GetSample()->GetBufferedRegion());
539  m_OutputImage->Allocate();
540 
541  int cluster;
542  unsigned int componentIndex;
543 
544  typename SampleType::ConstIterator sampleIter = this->GetSampleList()->Begin();
545  typename SampleType::ConstIterator sampleIterEnd = this->GetSampleList()->End();
546 
547  typename OutputType::ConstIterator outputIter = m_Output->Begin();
548  typename OutputType::ConstIterator outputIterEnd = m_Output->End();
549 
550  typename itk::ImageRegionIterator<TOutputImage> imgOutputIter(m_OutputImage, m_OutputImage->GetBufferedRegion());
551  imgOutputIter.GoToBegin();
552  typename itk::ImageRegionIterator<TOutputImage> imgOutputIterEnd(m_OutputImage, m_OutputImage->GetBufferedRegion());
553  imgOutputIterEnd.GoToEnd();
554 
555  do
556  {
557  cluster = 0;
558  for (componentIndex = 1; componentIndex < nbClasses; ++componentIndex)
559  {
560  if (m_Proba[componentIndex][sampleIter.GetInstanceIdentifier()] > m_Proba[cluster][sampleIter.GetInstanceIdentifier()])
561  cluster = componentIndex;
562  }
563 
564  m_Output->AddInstance(cluster, sampleIter.GetInstanceIdentifier());
565  imgOutputIter.Set(cluster);
566 
567  } while (++sampleIter != sampleIterEnd && ++outputIter != outputIterEnd && ++imgOutputIter != imgOutputIterEnd);
568 }
569 
570 template <class TInputImage, class TOutputImage>
572 {
573 
574  InitParameters();
575 
576  m_CurrentIteration = 0;
577  m_TerminationCode = NOT_CONVERGED;
578 
579  int oldNbChange = 0;
580  double step;
581 
582  do
583  {
584  oldNbChange = m_NbChange;
585 
586  PerformStochasticProcess();
587  PerformExpectationProcess();
588  PerformMaximizationProcess();
589 
590  step = static_cast<double>(oldNbChange - m_NbChange);
591  if (step >= 0.0)
592  {
593  if ((step / static_cast<double>(m_NbSamples)) < GetTerminationThreshold())
594  {
595  m_TerminationCode = CONVERGED;
596  if (oldNbChange != 0)
597  break;
598  }
599  }
600  } while (++m_CurrentIteration < m_MaximumIteration);
601 
602  GetMaximumAposterioriLabels();
603 }
604 
605 } // end of namespace otb
606 
607 #endif
TOutputImage * GetOutputImage()
std::vector< double > ProportionVectorType
void SetInitialProportions(ProportionVectorType &proportions)
SampleType * GetSampleList() const
void SetClassLabels(OutputType *labels)
int AddComponent(int id, ComponentType *component)
SampleType::MeasurementVectorType MeasurementVectorType
void Modified() const override
itk::Statistics::ListSample< typename TInputImage::PixelType > SampleType
OutputType * GetOutput()
void SetSample(const TInputImage *sample)
ClassLabelVectorType & GetClassLabels()
void Update() override
void SetNeighborhood(int neighborhood)
std::vector< ClassLabelType > ClassLabelVectorType
void PrintSelf(std::ostream &os, itk::Indent indent) const override
itk::Statistics::MembershipSample< SampleType > OutputType
const TInputImage * GetSample() const
is a component (derived from ModelComponentBase) for Gaussian class. This class is used in SEMClassif...
base class for distribution representation that supports analytical way to update the distribution pa...
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
#define otbMsgDebugMacro(x)
Definition: otbMacro.h:114
#define otbMsgDevMacro(x)
Definition: otbMacro.h:116