21 #ifndef otbTrainRandomForests_hxx
22 #define otbTrainRandomForests_hxx
31 template <
class TInputValue,
class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitRandomForestsParams()
34 AddChoice(
"classifier.rf",
"Random forests classifier");
35 SetParameterDescription(
"classifier.rf",
"http://docs.opencv.org/modules/ml/doc/random_trees.html");
37 AddParameter(
ParameterType_Int,
"classifier.rf.max",
"Maximum depth of the tree");
38 SetParameterInt(
"classifier.rf.max", 5);
39 SetParameterDescription(
"classifier.rf.max",
40 "The depth of the tree. A low value will likely underfit and conversely a high value will likely overfit. "
41 "The optimal value can be obtained using cross validation or other suitable methods.");
44 AddParameter(
ParameterType_Int,
"classifier.rf.min",
"Minimum number of samples in each node");
45 SetParameterInt(
"classifier.rf.min", 10);
46 SetParameterDescription(
"classifier.rf.min",
47 "If the number of samples in a node is smaller than this parameter, "
48 "then the node will not be split. A reasonable value is a small percentage of the total data e.g. 1 percent.");
51 AddParameter(
ParameterType_Float,
"classifier.rf.ra",
"Termination Criteria for regression tree");
52 SetParameterFloat(
"classifier.rf.ra", 0.);
53 SetParameterDescription(
"classifier.rf.ra",
54 "If all absolute differences between an estimated value in a node "
55 "and the values of the train samples in this node are smaller than this regression accuracy parameter, "
56 "then the node will not be split.");
59 AddParameter(
ParameterType_Int,
"classifier.rf.cat",
"Cluster possible values of a categorical variable into K <= cat clusters to find a suboptimal split");
60 SetParameterInt(
"classifier.rf.cat", 10);
61 SetParameterDescription(
"classifier.rf.cat",
"Cluster possible values of a categorical variable into K <= cat clusters to find a suboptimal split.");
68 AddParameter(
ParameterType_Int,
"classifier.rf.var",
"Size of the randomly selected subset of features at each tree node");
69 SetParameterInt(
"classifier.rf.var", 0);
70 SetParameterDescription(
"classifier.rf.var",
71 "The size of the subset of features, randomly selected at each tree node, that are used to find the best split(s). "
72 "If you set it to 0, then the size will be set to the square root of the total number of features.");
75 AddParameter(
ParameterType_Int,
"classifier.rf.nbtrees",
"Maximum number of trees in the forest");
76 SetParameterInt(
"classifier.rf.nbtrees", 100);
77 SetParameterDescription(
"classifier.rf.nbtrees",
78 "The maximum number of trees in the forest. Typically, the more trees you have, the better the accuracy. "
79 "However, the improvement in accuracy generally diminishes and reaches an asymptote for a certain number of trees. "
80 "Also to keep in mind, increasing the number of trees increases the prediction time linearly.");
84 SetParameterFloat(
"classifier.rf.acc", 0.01);
85 SetParameterDescription(
"classifier.rf.acc",
"Sufficient accuracy (OOB error).");
91 template <
class TInputValue,
class TOutputValue>
92 void LearningApplicationBase<TInputValue, TOutputValue>::TrainRandomForests(
typename ListSampleType::Pointer trainingListSample,
93 typename TargetListSampleType::Pointer trainingLabeledListSample,
94 std::string modelPath)
97 typename RandomForestType::Pointer classifier = RandomForestType::New();
98 classifier->SetRegressionMode(this->m_RegressionFlag);
99 classifier->SetInputListSample(trainingListSample);
100 classifier->SetTargetListSample(trainingLabeledListSample);
101 classifier->SetMaxDepth(GetParameterInt(
"classifier.rf.max"));
102 classifier->SetMinSampleCount(GetParameterInt(
"classifier.rf.min"));
103 classifier->SetRegressionAccuracy(GetParameterFloat(
"classifier.rf.ra"));
104 classifier->SetMaxNumberOfCategories(GetParameterInt(
"classifier.rf.cat"));
105 classifier->SetMaxNumberOfVariables(GetParameterInt(
"classifier.rf.var"));
106 classifier->SetMaxNumberOfTrees(GetParameterInt(
"classifier.rf.nbtrees"));
107 classifier->SetForestAccuracy(GetParameterFloat(
"classifier.rf.acc"));
110 classifier->Save(modelPath);