21 #ifndef otbTrainSharkRandomForests_hxx
22 #define otbTrainSharkRandomForests_hxx
32 template <
class TInputValue,
class TOutputValue>
33 void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkRandomForestsParams()
37 AddChoice(
"classifier.sharkrf",
"Shark Random forests classifier");
38 SetParameterDescription(
"classifier.sharkrf",
39 "http://image.diku.dk/shark/doxygen_pages/html/classshark_1_1_r_f_trainer.html.\n It is noteworthy that training is parallel.");
41 AddParameter(
ParameterType_Int,
"classifier.sharkrf.nbtrees",
"Maximum number of trees in the forest");
42 SetParameterInt(
"classifier.sharkrf.nbtrees", 100);
43 SetParameterDescription(
"classifier.sharkrf.nbtrees",
44 "The maximum number of trees in the forest. Typically, the more trees you have, the better the accuracy. "
45 "However, the improvement in accuracy generally diminishes and reaches an asymptote for a certain number of trees. "
46 "Also to keep in mind, increasing the number of trees increases the prediction time linearly.");
50 AddParameter(
ParameterType_Int,
"classifier.sharkrf.nodesize",
"Min size of the node for a split");
51 SetParameterInt(
"classifier.sharkrf.nodesize", 25);
52 SetParameterDescription(
"classifier.sharkrf.nodesize",
53 "If the number of samples in a node is smaller than this parameter, "
54 "then the node will not be split. A reasonable value is a small percentage of the total data e.g. 1 percent.");
57 AddParameter(
ParameterType_Int,
"classifier.sharkrf.mtry",
"Number of features tested at each node");
58 SetParameterInt(
"classifier.sharkrf.mtry", 0);
59 SetParameterDescription(
"classifier.sharkrf.mtry",
60 "The number of features (variables) which will be tested at each node in "
61 "order to compute the split. If set to zero, the square root of the number of "
67 SetParameterFloat(
"classifier.sharkrf.oobr", 0.66);
68 SetParameterDescription(
"classifier.sharkrf.oobr",
69 "Set the fraction of the original training dataset to use as the out of bag sample."
70 "A good default value is 0.66. ");
73 template <
class TInputValue,
class TOutputValue>
74 void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkRandomForests(
typename ListSampleType::Pointer trainingListSample,
75 typename TargetListSampleType::Pointer trainingLabeledListSample,
76 std::string modelPath)
79 typename SharkRandomForestType::Pointer classifier = SharkRandomForestType::New();
80 classifier->SetRegressionMode(this->m_RegressionFlag);
81 classifier->SetInputListSample(trainingListSample);
82 classifier->SetTargetListSample(trainingLabeledListSample);
83 classifier->SetNodeSize(GetParameterInt(
"classifier.sharkrf.nodesize"));
84 classifier->SetOobRatio(GetParameterFloat(
"classifier.sharkrf.oobr"));
85 classifier->SetNumberOfTrees(GetParameterInt(
"classifier.sharkrf.nbtrees"));
86 classifier->SetMTry(GetParameterInt(
"classifier.sharkrf.mtry"));
89 classifier->Save(modelPath);