21 #ifndef otbDecisionTreeMachineLearningModel_hxx
22 #define otbDecisionTreeMachineLearningModel_hxx
// Constructor: creates the underlying OpenCV decision tree and initialises the
// training parameters. Visible initialisers: regression accuracy 0.01,
// surrogate splits disabled, pruned-tree truncation enabled.
// NOTE(review): this copy appears truncated -- the leading numbers look like
// leaked line numbers and the constructor signature, remaining initialisers and
// body braces are missing; restore from upstream before compiling.
33 template <
class TInputValue,
class TOutputValue>
36 m_DTreeModel(cv::ml::DTrees::create()), // OpenCV decision-tree instance owned via cv::Ptr
39 m_RegressionAccuracy(0.01),
40 m_UseSurrogates(false),
43 m_TruncatePrunedTree(true)
// Train: converts the input and target list samples to OpenCV matrices,
// configures the decision tree from the member parameters, and trains it.
// NOTE(review): this copy appears truncated (leading numbers look like leaked
// line numbers; signature, local declarations and braces are missing) --
// restore from upstream before compiling.
49 template <
class TInputValue,
class TOutputValue>
// Feature matrix; laid out one sample per row (ROW_SAMPLE in train() below).
54 otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
// Response matrix (class labels or regression targets).
58 otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
// Variable-type descriptor: one entry per feature plus one for the response,
// which is the last row.
// NOTE(review): cv::Mat(rows, cols, type) does not initialise its data, and no
// visible line sets the feature entries. Confirm they are explicitly set to
// numerical (e.g. via var_type.setTo(...)) before training.
60 cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
// Classification: mark the response entry as categorical. In regression mode
// this assignment is skipped.
63 if (!this->m_RegressionMode)
64 var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) =
CV_VAR_CATEGORICAL;
66 m_DTreeModel->setMaxDepth(m_MaxDepth);
67 m_DTreeModel->setMinSampleCount(m_MinSampleCount);
68 m_DTreeModel->setRegressionAccuracy(m_RegressionAccuracy);
69 m_DTreeModel->setUseSurrogates(m_UseSurrogates);
// Cross-validation folds are hard-coded to 0 rather than read from a member.
// Per the OpenCV DTrees docs, values > 1 enable K-fold cross-validation
// pruning, so this disables it.
// NOTE(review): confirm the reason (presumably an OpenCV workaround).
71 m_DTreeModel->setCVFolds(0);
72 m_DTreeModel->setMaxCategories(m_MaxCategories);
73 m_DTreeModel->setUse1SERule(m_Use1seRule);
74 m_DTreeModel->setTruncatePrunedTree(m_TruncatePrunedTree);
75 m_DTreeModel->setPriors(cv::Mat(m_Priors)); // class priors wrapped as a cv::Mat
// Samples are rows; no variable/sample index masks or sample weights are used.
76 m_DTreeModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
// Predict: runs the trained tree on a single input sample. The raw prediction
// is cast to TOutputValue and stored in target[0].
// An ITK exception is raised when a quality output is requested and
// m_ConfidenceIndex is unset, or when per-class probabilities are requested
// and m_ProbaIndex is unset.
// NOTE(review): this copy appears truncated -- signature, local declarations,
// braces and the return statement are missing; restore from upstream.
79 template <
class TInputValue,
class TOutputValue>
88 otb::SampleToMat<InputSampleType>(input, sample);
89 double result = m_DTreeModel->predict(sample);
91 target[0] =
static_cast<TOutputValue
>(result);
// Caller asked for a quality (confidence) value.
93 if (quality !=
nullptr)
95 if (!this->m_ConfidenceIndex)
97 itkExceptionMacro(
"Confidence index not available for this classifier !");
// Caller asked for per-class probabilities, but they are not enabled.
100 if (proba !=
nullptr && !this->m_ProbaIndex)
101 itkExceptionMacro(
"Probability per class not available for this classifier !");
// Save: serialises the model to `filename` with cv::FileStorage. The model is
// written inside a node named `name`, or the model's default name when `name`
// is empty.
// NOTE(review): this copy appears truncated -- the closing "}" for the node
// opened below is not visible; confirm it is emitted after write(fs).
116 template <
class TInputValue,
class TOutputValue>
109 cv::FileStorage fs(filename, cv::FileStorage::WRITE);
110 fs << (name.empty() ? m_DTreeModel->getDefaultName() : cv::String(name)) <<
"{";
111 m_DTreeModel->write(fs);
// Load: reads the model from `filename` with cv::FileStorage. When `name` is
// empty the first top-level node is used, otherwise the node named `name`.
// NOTE(review): this copy appears truncated (signature and braces missing).
116 template <
class TInputValue,
class TOutputValue>
119 cv::FileStorage fs(filename, cv::FileStorage::READ);
120 m_DTreeModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
// CanReadFile: heuristic check that `file` contains a decision-tree model.
// Prints an error to std::cerr when the file cannot be read. It then reads
// text lines with std::getline and checks whether a line contains either
// CV_TYPE_NAME_ML_TREE or the model's default name.
// NOTE(review): this copy appears truncated -- the stream setup, loop and
// return statements are not visible; restore from upstream.
123 template <
class TInputValue,
class TOutputValue>
131 std::cerr <<
"Could not read file " << file << std::endl;
138 std::getline(ifs, line);
141 if (line.find(
CV_TYPE_NAME_ML_TREE) != std::string::npos || line.find(m_DTreeModel->getDefaultName()) != std::string::npos)
// NOTE(review): only this template header survives in this copy; the member
// definition it introduces is missing -- restore from upstream.
150 template <
class TInputValue,
class TOutputValue>
// PrintSelf: prints the object's state to `os`. The only visible statement
// delegates to Superclass::PrintSelf.
// NOTE(review): this copy appears truncated (signature and braces missing).
156 template <
class TInputValue,
class TOutputValue>
160 Superclass::PrintSelf(os, indent);