20 #ifndef otbTrainSharkKMeans_hxx
21 #define otbTrainSharkKMeans_hxx
31 template <
class TInputValue,
class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
34 AddChoice(
"classifier.sharkkm",
"Shark kmeans classifier");
35 SetParameterDescription(
"classifier.sharkkm",
"http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html ");
38 AddParameter(
ParameterType_Int,
"classifier.sharkkm.maxiter",
"Maximum number of iterations for the kmeans algorithm");
39 SetParameterInt(
"classifier.sharkkm.maxiter", 10);
40 SetMinimumParameterIntValue(
"classifier.sharkkm.maxiter", 0);
41 SetParameterDescription(
"classifier.sharkkm.maxiter",
"The maximum number of iterations for the kmeans algorithm. 0=unlimited");
44 AddParameter(
ParameterType_Int,
"classifier.sharkkm.k",
"Number of classes for the kmeans algorithm");
45 SetParameterInt(
"classifier.sharkkm.k", 2);
46 SetParameterDescription(
"classifier.sharkkm.k",
"The number of classes used for the kmeans algorithm. Default set to 2 class");
47 SetMinimumParameterIntValue(
"classifier.sharkkm.k", 2);
51 SetParameterDescription(
"classifier.sharkkm.incentroids",
52 "Input text file containing centroid posistions used to initialize the algorithm. "
53 "Each centroid must be described by p parameters, p being the number of features in "
54 "the input vector data, and the number of centroids must be equal to the number of classes "
55 "(one centroid per line with values separated by spaces).");
56 MandatoryOff(
"classifier.sharkkm.incentroids");
60 SetParameterDescription(
"classifier.sharkkm.cstats",
61 "A XML file containing mean and standard deviation to center"
62 "and reduce the input centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
63 MandatoryOff(
"classifier.sharkkm.cstats");
67 SetParameterDescription(
"classifier.sharkkm.outcentroids",
"Output text file containing centroids after the kmean algorithm.");
68 MandatoryOff(
"classifier.sharkkm.outcentroids");
71 template <
class TInputValue,
class TOutputValue>
72 void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
typename ListSampleType::Pointer trainingListSample,
73 typename TargetListSampleType::Pointer trainingLabeledListSample,
74 std::string modelPath)
76 unsigned int nbMaxIter =
static_cast<unsigned int>(abs(GetParameterInt(
"classifier.sharkkm.maxiter")));
77 unsigned int k =
static_cast<unsigned int>(abs(GetParameterInt(
"classifier.sharkkm.k")));
80 typename SharkKMeansType::Pointer classifier = SharkKMeansType::New();
81 classifier->SetRegressionMode(this->m_RegressionFlag);
82 classifier->SetInputListSample(trainingListSample);
83 classifier->SetTargetListSample(trainingLabeledListSample);
87 if (IsParameterEnabled(
"classifier.sharkkm.incentroids") && HasValue(
"classifier.sharkkm.incentroids"))
89 shark::Data<shark::RealVector> centroidData;
90 shark::importCSV(centroidData, GetParameterString(
"classifier.sharkkm.incentroids"),
' ');
91 if (HasValue(
"classifier.sharkkm.cstats"))
94 statisticsReader->SetFileName(GetParameterString(
"classifier.sharkkm.cstats"));
95 auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName(
"mean");
96 auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName(
"stddev");
99 shark::RealVector offsetRV(meanMeasurementVector.Size());
100 shark::RealVector scaleRV(stddevMeasurementVector.Size());
102 assert(meanMeasurementVector.Size() == stddevMeasurementVector.Size());
103 for (
unsigned int i = 0; i < meanMeasurementVector.Size(); ++i)
105 scaleRV[i] = 1 / stddevMeasurementVector[i];
107 offsetRV[i] = -meanMeasurementVector[i] / stddevMeasurementVector[i];
110 shark::Normalizer<> normalizer(scaleRV, offsetRV);
111 centroidData = normalizer(centroidData);
114 if (centroidData.numberOfElements() != k)
115 otbAppLogWARNING(
"The input centroid file will not be used because it contains "
116 << centroidData.numberOfElements() <<
" points, which is different than from the requested number of class: " << k <<
".");
118 classifier->SetCentroidsFromData(centroidData);
121 classifier->SetMaximumNumberOfIterations(nbMaxIter);
123 classifier->Save(modelPath);
125 if (HasValue(
"classifier.sharkkm.outcentroids"))
126 classifier->ExportCentroids(GetParameterString(
"classifier.sharkkm.outcentroids"));