21 #ifndef otbTrainDecisionTree_hxx
22 #define otbTrainDecisionTree_hxx
// Registers the "classifier.dt" (decision tree) classifier choice and its
// sub-parameters on the application:
//   max - maximum tree depth            (int,   default 10)
//   min - minimum samples per node      (int,   default 10)
//   ra  - regression accuracy criterion (float, default 0.01)
//   cat - maximum number of categories  (int,   default 10)
//   r, t - pruning switches (no default set here).
// The values are read back by TrainDecisionTree.
31 template <
class TInputValue,
class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitDecisionTreeParams()
// Classifier choice entry; its description points at the OpenCV decision tree docs.
34 AddChoice(
"classifier.dt",
"Decision Tree classifier");
35 SetParameterDescription(
"classifier.dt",
"http://docs.opencv.org/modules/ml/doc/decision_trees.html");
// Maximum depth of the tree (int, default 10); forwarded to SetMaxDepth.
37 AddParameter(
ParameterType_Int,
"classifier.dt.max",
"Maximum depth of the tree");
38 SetParameterInt(
"classifier.dt.max", 10);
39 SetParameterDescription(
"classifier.dt.max",
40 "The training algorithm attempts to split each node while its depth is smaller "
41 "than the maximum possible depth of the tree. The actual depth may be smaller "
42 "if the other termination criteria are met, and/or if the tree is pruned.");
// Minimum number of samples a node needs before it may be split
// (int, default 10); forwarded to SetMinSampleCount.
45 AddParameter(
ParameterType_Int,
"classifier.dt.min",
"Minimum number of samples in each node");
46 SetParameterInt(
"classifier.dt.min", 10);
47 SetParameterDescription(
"classifier.dt.min",
48 "If the number of samples in a node is smaller "
49 "than this parameter, then this node will not be split.");
// Regression termination criterion (float, default 0.01); forwarded to
// SetRegressionAccuracy.
52 AddParameter(
ParameterType_Float,
"classifier.dt.ra",
"Termination criteria for regression tree");
53 SetParameterFloat(
"classifier.dt.ra", 0.01);
54 SetParameterDescription(
"classifier.dt.ra",
55 "If all absolute differences between an estimated value in a node "
56 "and the values of the train samples in this node are smaller than this "
57 "regression accuracy parameter, then the node will not be split further.");
// Maximum number of categories (int, default 10); forwarded to SetMaxCategories.
// NOTE(review): the AddParameter call that registers "classifier.dt.cat" is not
// visible here, and the next line is a detached string fragment - confirm the
// registration is intact before the SetParameterInt call that follows.
64 "Cluster possible values of a categorical variable into K <= cat clusters to find a "
66 SetParameterInt(
"classifier.dt.cat", 10);
67 SetParameterDescription(
"classifier.dt.cat",
68 "Cluster possible values of a categorical variable into K <= cat clusters to find a "
// Use1seRule switch: when set, TrainDecisionTree calls SetUse1seRule(false).
// NOTE(review): the description below reads as if it describes the Use1seRule
// flag being enabled (harsher pruning). Setting this parameter does the
// opposite - confirm the wording.
73 SetParameterDescription(
"classifier.dt.r",
74 "If true, then a pruning will be harsher. This will make a tree more compact and more "
75 "resistant to the training data noise but a bit less accurate.");
// TruncatePrunedTree switch (bool): when set, TrainDecisionTree calls
// SetTruncatePrunedTree(false), as the label states.
// NOTE(review): the description says pruned branches are physically removed
// "if true". That appears to describe the underlying flag being enabled, not
// this parameter (which disables it) - confirm the wording.
78 AddParameter(
ParameterType_Bool,
"classifier.dt.t",
"Set TruncatePrunedTree flag to false");
79 SetParameterDescription(
"classifier.dt.t",
"If true, then pruned branches are physically removed from the tree.");
// Configures a DecisionTreeType classifier from the "classifier.dt.*"
// application parameters, then saves the model to modelPath.
//
// trainingListSample        - input feature samples
// trainingLabeledListSample - target samples (presumably class labels, or
//                             regression targets when m_RegressionFlag is set)
// modelPath                 - destination passed to the classifier's Save()
84 template <
class TInputValue,
class TOutputValue>
85 void LearningApplicationBase<TInputValue, TOutputValue>::TrainDecisionTree(
typename ListSampleType::Pointer trainingListSample,
86 typename TargetListSampleType::Pointer trainingLabeledListSample,
87 std::string modelPath)
// Create the classifier, select classification vs. regression, attach the data.
90 typename DecisionTreeType::Pointer classifier = DecisionTreeType::New();
91 classifier->SetRegressionMode(this->m_RegressionFlag);
92 classifier->SetInputListSample(trainingListSample);
93 classifier->SetTargetListSample(trainingLabeledListSample);
// Copy the numeric hyper-parameters registered in InitDecisionTreeParams.
94 classifier->SetMaxDepth(GetParameterInt(
"classifier.dt.max"));
95 classifier->SetMinSampleCount(GetParameterInt(
"classifier.dt.min"));
96 classifier->SetRegressionAccuracy(GetParameterFloat(
"classifier.dt.ra"));
97 classifier->SetMaxCategories(GetParameterInt(
"classifier.dt.cat"));
// The two switches only ever turn the corresponding flag off. When a switch
// is not set, the classifier's own default for that flag is left untouched.
99 if (GetParameterInt(
"classifier.dt.r"))
101 classifier->SetUse1seRule(
false);
103 if (GetParameterInt(
"classifier.dt.t"))
105 classifier->SetTruncatePrunedTree(
false);
// NOTE(review): no training call is visible between configuration and Save()
// in this listing - confirm the model is trained before it is written.
108 classifier->Save(modelPath);