21 #ifndef otbTrainKNN_hxx
22 #define otbTrainKNN_hxx
// Registers the "classifier.knn" choice and its sub-parameters on the
// application's parameter tree.
//  - classifier.knn.k    : number of neighbours, default value set to 32.
//  - classifier.knn.rule : only populated when m_RegressionFlag is set; offers
//                          "mean" and "median" decision rules for regression.
// NOTE(review): the AddParameter(...) declarations for "classifier.knn.k" and
// "classifier.knn.rule" are not visible in this excerpt; presumably they sit
// just before the corresponding SetParameterInt / SetParameterDescription calls.
31 template <
class TInputValue,
class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitKNNParams()
// Add "knn" as one of the available classifier choices.
34 AddChoice(
"classifier.knn",
"KNN classifier");
// NOTE(review): this URL follows the legacy OpenCV 2.x documentation layout;
// confirm it still resolves.
35 SetParameterDescription(
"classifier.knn",
"http://docs.opencv.org/modules/ml/doc/k_nearest_neighbors.html");
// K (number of neighbours) defaults to 32.
39 SetParameterInt(
"classifier.knn.k", 32);
40 SetParameterDescription(
"classifier.knn.k",
"The number of neighbors to use.");
// The decision-rule choice is only exposed in regression mode; in
// classification mode no "classifier.knn.rule" entries are configured here.
42 if (this->m_RegressionFlag)
46 SetParameterDescription(
"classifier.knn.rule",
"Decision rule for regression output");
// Choice "mean": TrainKNN maps this to KNNType::KNN_MEAN.
48 AddChoice(
"classifier.knn.rule.mean",
"Mean of neighbors values");
49 SetParameterDescription(
"classifier.knn.rule.mean",
"Returns the mean of neighbors values");
// Choice "median": TrainKNN maps this to KNNType::KNN_MEDIAN.
51 AddChoice(
"classifier.knn.rule.median",
"Median of neighbors values");
52 SetParameterDescription(
"classifier.knn.rule.median",
"Returns the median of neighbors values");
// Builds a KNN model from the application's "classifier.knn.*" parameters,
// trains it on the given samples and writes it to disk.
//
// Parameters:
//   trainingListSample        - input feature samples handed to SetInputListSample.
//   trainingLabeledListSample - target samples (labels or regression values)
//                               handed to SetTargetListSample.
//   modelPath                 - destination passed to the model's Save().
// NOTE(review): the KNNType typedef is not visible in this excerpt; presumably it
// is a KNN machine-learning model type declared inside this function.
56 template <
class TInputValue,
class TOutputValue>
57 void LearningApplicationBase<TInputValue, TOutputValue>::TrainKNN(
typename ListSampleType::Pointer trainingListSample,
58 typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
61 typename KNNType::Pointer knnClassifier = KNNType::New();
// Mirror the application's classification/regression mode on the model.
62 knnClassifier->SetRegressionMode(this->m_RegressionFlag);
63 knnClassifier->SetInputListSample(trainingListSample);
64 knnClassifier->SetTargetListSample(trainingLabeledListSample);
// K comes from the "classifier.knn.k" parameter (default 32 per InitKNNParams).
65 knnClassifier->SetK(GetParameterInt(
"classifier.knn.k"));
// The decision rule only applies to regression output.
66 if (this->m_RegressionFlag)
68 std::string decision = this->GetParameterString(
"classifier.knn.rule");
69 if (decision ==
"mean")
71 knnClassifier->SetDecisionRule(KNNType::KNN_MEAN);
73 else if (decision ==
"median")
75 knnClassifier->SetDecisionRule(KNNType::KNN_MEDIAN);
// Any other value results in no SetDecisionRule call, so the model keeps
// whatever rule it was constructed with.
// Fit the model, then persist it to modelPath.
79 knnClassifier->Train();
80 knnClassifier->Save(modelPath);