21 #ifndef otbKMeansAttributesLabelMapFilter_hxx
22 #define otbKMeansAttributesLabelMapFilter_hxx
25 #include "itkNumericTraits.h"
26 #include "itkMersenneTwisterRandomVariateGenerator.h"
31 template <
class TInputImage>
37 template <
class TInputImage>
40 m_LabelMapToSampleListFilter->SetInputLabelMap(m_InputLabelMap);
41 m_LabelMapToSampleListFilter->Update();
43 typename ListSampleType::Pointer listSamples =
const_cast<ListSampleType*
>(m_LabelMapToSampleListFilter->GetOutputSampleList());
44 typename TrainingListSampleType::Pointer trainingSamples =
const_cast<TrainingListSampleType*
>(m_LabelMapToSampleListFilter->GetOutputTrainingSampleList());
47 typename TreeGeneratorType::Pointer kdTreeGenerator = TreeGeneratorType::New();
48 kdTreeGenerator->SetSample(listSamples);
49 kdTreeGenerator->SetBucketSize(100);
50 kdTreeGenerator->Update();
52 unsigned int sampleSize = listSamples->GetMeasurementVector(0).Size();
53 const unsigned int OneClassNbCentroids = 10;
54 unsigned int numberOfCentroids = (m_NumberOfClasses == 1 ? OneClassNbCentroids : m_NumberOfClasses);
55 typename EstimatorType::ParametersType initialMeans(sampleSize * m_NumberOfClasses);
56 initialMeans.Fill(0.);
58 if (m_NumberOfClasses > 1)
61 for (
ClassLabelType classLabel = 0; classLabel < m_NumberOfClasses; ++classLabel)
63 typename TrainingListSampleType::ConstIterator it = trainingSamples->Begin();
66 for (it = trainingSamples->Begin(); it != trainingSamples->End(); ++it)
68 std::cout <<
" Training Samples is " << it.GetMeasurementVector()[0] << std::endl;
69 if (it.GetMeasurementVector()[0] == classLabel)
72 if (it == trainingSamples->End())
74 itkExceptionMacro(<<
"Unable to find a sample with class label " << classLabel);
77 typename ListSampleType::InstanceIdentifier identifier = it.GetInstanceIdentifier();
78 const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier);
79 for (
unsigned int i = 0; i < centroid.Size(); ++i)
81 initialMeans[classLabel * sampleSize + i] = centroid[i];
87 typedef itk::Statistics::MersenneTwisterRandomVariateGenerator RandomGeneratorType;
88 RandomGeneratorType::Pointer randomGenerator = RandomGeneratorType::GetInstance();
89 unsigned int nbLabelObjects = listSamples->Size();
92 for (
unsigned int centroidId = 0; centroidId < numberOfCentroids; ++centroidId)
94 typename ListSampleType::InstanceIdentifier identifier = randomGenerator->GetIntegerVariate(nbLabelObjects - 1);
95 const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier);
96 for (
unsigned int i = 0; i < centroid.Size(); ++i)
98 initialMeans[centroidId * sampleSize + i] = centroid[i];
105 typename EstimatorType::Pointer estimator = EstimatorType::New();
106 estimator->SetParameters(initialMeans);
107 estimator->SetKdTree(kdTreeGenerator->GetOutput());
108 estimator->SetMaximumIteration(10000);
109 estimator->SetCentroidPositionChangesThreshold(0.00001);
110 estimator->StartOptimization();
115 for (
unsigned int cId = 0; cId < numberOfCentroids; ++cId)
118 for (
unsigned int i = 0; i < sampleSize; ++i)
120 newCenter[i] = estimator->GetParameters()[cId * sampleSize + i];
122 m_Centroids.push_back(newCenter);