21 #ifndef otbRandomForestsMachineLearningModel_hxx
22 #define otbRandomForestsMachineLearningModel_hxx
32 template <
class TInputValue,
class TOutputValue>
38 m_RegressionAccuracy(0.01),
39 m_ComputeSurrogateSplit(false),
40 m_MaxNumberOfCategories(10),
41 m_CalculateVariableImportance(false),
42 m_MaxNumberOfVariables(0),
43 m_MaxNumberOfTrees(100),
44 m_ForestAccuracy(0.01),
45 m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS),
46 m_ComputeMargin(false)
53 template <
class TInputValue,
class TOutputValue>
58 otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
61 otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
63 cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
66 if (this->m_RegressionMode)
67 var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) =
CV_VAR_NUMERICAL;
69 var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) =
CV_VAR_CATEGORICAL;
71 return m_RFModel->calcError(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type),
false,
76 template <
class TInputValue,
class TOutputValue>
81 otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
85 otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
87 cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
90 if (this->m_RegressionMode)
91 var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) =
CV_VAR_NUMERICAL;
93 var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) =
CV_VAR_CATEGORICAL;
99 m_RFModel->setMaxDepth(m_MaxDepth);
100 m_RFModel->setMinSampleCount(m_MinSampleCount);
101 m_RFModel->setRegressionAccuracy(m_RegressionAccuracy);
102 m_RFModel->setUseSurrogates(m_ComputeSurrogateSplit);
103 m_RFModel->setMaxCategories(m_MaxNumberOfCategories);
104 m_RFModel->setPriors(cv::Mat(m_Priors));
105 m_RFModel->setCalculateVarImportance(m_CalculateVariableImportance);
106 m_RFModel->setActiveVarCount(m_MaxNumberOfVariables);
107 m_RFModel->setTermCriteria(cv::TermCriteria(m_TerminationCriteria, m_MaxNumberOfTrees, m_ForestAccuracy));
108 m_RFModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
111 template <
class TInputValue,
class TOutputValue>
121 otb::SampleToMat<InputSampleType>(value, sample);
123 double result = m_RFModel->predict(sample);
125 target[0] =
static_cast<TOutputValue
>(result);
127 if (quality !=
nullptr)
130 (*quality) = m_RFModel->predict_margin(sample);
132 (*quality) = m_RFModel->predict_confidence(sample);
135 if (proba !=
nullptr && !this->m_ProbaIndex)
136 itkExceptionMacro(
"Probability per class not available for this classifier !");
141 template <
class TInputValue,
class TOutputValue>
144 cv::FileStorage fs(filename, cv::FileStorage::WRITE);
145 fs << (name.empty() ? m_RFModel->getDefaultName() : cv::String(name)) <<
"{";
146 m_RFModel->write(fs);
151 template <
class TInputValue,
class TOutputValue>
154 cv::FileStorage fs(filename, cv::FileStorage::READ);
155 m_RFModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
158 template <
class TInputValue,
class TOutputValue>
166 std::cerr <<
"Could not read file " << file << std::endl;
174 std::getline(ifs, line);
177 if (line.find(
CV_TYPE_NAME_ML_RTREES) != std::string::npos || line.find(m_RFModel->getDefaultName()) != std::string::npos)
186 template <
class TInputValue,
class TOutputValue>
192 template <
class TInputValue,
class TOutputValue>
196 cv::Mat cvMat = m_RFModel->getVarImportance();
198 for (
int i = 0; i < cvMat.rows; i++)
200 for (
int j = 0; j < cvMat.cols; j++)
202 itkMat(i, j) = cvMat.at<
float>(i, j);
209 template <
class TInputValue,
class TOutputValue>
213 Superclass::PrintSelf(os, indent);