21 #ifndef otbBoostMachineLearningModel_hxx
22 #define otbBoostMachineLearningModel_hxx
/** Constructor.
 *
 *  Creates the wrapped OpenCV model with cv::ml::Boost::create() and sets the
 *  weight-trim rate m_WeightTrimRate to 0.95.
 *
 *  NOTE(review): these are the only member initializers present here. Confirm
 *  that m_BoostType, m_WeakCount and m_MaxDepth (all passed to the OpenCV model
 *  by the training code) are also initialized.
 */
33 template <
class TInputValue,
class TOutputValue>
35 : m_BoostModel(cv::ml::Boost::create()),
38 m_WeightTrimRate(0.95),
/** Train the Boost model on the input/target list samples.
 *
 *  Steps:
 *  - Converts the input list sample into `samples` and the target list sample
 *    into `labels` with otb::ListSampleToMat.
 *  - Builds a CV_8U variable-type column `var_type` with
 *    GetMeasurementVectorSize() + 1 rows. The last row describes the response
 *    and is marked CV_VAR_CATEGORICAL, so the model is trained as a classifier.
 *  - Copies m_BoostType, m_WeakCount, m_WeightTrimRate and m_MaxDepth onto
 *    m_BoostModel, disables surrogate splits and clears the priors.
 *  - Trains on a TrainData built with cv::ml::ROW_SAMPLE, i.e. one sample per
 *    row of `samples`.
 */
45 template <
class TInputValue,
class TOutputValue>
50 otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
54 otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
// NOTE(review): cv::Mat(rows, cols, type) does not initialize its elements, and
// only the last (response) entry is assigned below. Confirm that the input rows
// are set (e.g. to CV_VAR_NUMERICAL) before var_type reaches TrainData::create.
56 cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
// The extra last row is the response variable: categorical => classification.
58 var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) =
CV_VAR_CATEGORICAL;
// Push the user-configured hyper-parameters onto the OpenCV model.
60 m_BoostModel->setBoostType(m_BoostType);
61 m_BoostModel->setWeakCount(m_WeakCount);
62 m_BoostModel->setWeightTrimRate(m_WeightTrimRate);
63 m_BoostModel->setMaxDepth(m_MaxDepth);
64 m_BoostModel->setUseSurrogates(
false);
65 m_BoostModel->setPriors(cv::Mat());
66 m_BoostModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
/** Predict the label of a single input sample.
 *
 *  Converts `input` into `sample` with otb::SampleToMat. It then stores
 *  m_BoostModel->predict(sample) in `result` and finally writes `result`,
 *  cast to TOutputValue, into target[0].
 *  - If `quality` is non-null, it receives the value of a second prediction made
 *    with cv::ml::StatModel::RAW_OUTPUT, cast to ConfidenceValueType.
 *  - If `proba` is non-null and m_ProbaIndex is not set, an ITK exception is
 *    raised: per-class probabilities are not produced by this classifier.
 */
69 template <
class TInputValue,
class TOutputValue>
78 otb::SampleToMat<InputSampleType>(input, sample);
81 result = m_BoostModel->predict(sample);
// RAW_OUTPUT asks OpenCV for the raw boosting response instead of the class
// label; note that this costs a second predict() call on the same sample.
83 if (quality !=
nullptr)
85 (*quality) =
static_cast<ConfidenceValueType>(m_BoostModel->predict(sample, cv::noArray(), cv::ml::StatModel::RAW_OUTPUT));
87 if (proba !=
nullptr && !this->m_ProbaIndex)
88 itkExceptionMacro(
"Probability per class not available for this classifier !");
90 target[0] =
static_cast<TOutputValue
>(result);
/** Save the model to `filename` through cv::FileStorage (write mode).
 *
 *  The model is written inside a map node named `name`, or named
 *  m_BoostModel->getDefaultName() when `name` is empty.
 *
 *  NOTE(review): the map opened with "{" must be closed with `fs << "}"`
 *  after m_BoostModel->write(fs). That statement does not appear among the
 *  lines below; confirm it is emitted, otherwise the stored file is malformed.
 */
94 template <
class TInputValue,
class TOutputValue>
97 cv::FileStorage fs(filename, cv::FileStorage::WRITE);
98 fs << (name.empty() ? m_BoostModel->getDefaultName() : cv::String(name)) <<
"{";
99 m_BoostModel->write(fs);
/** Load the model from `filename` through cv::FileStorage (read mode).
 *
 *  Reads the model from the node named `name`, or from the first top-level
 *  node of the file when `name` is empty.
 *
 *  NOTE(review): fs.isOpened() is not checked before read(). A missing or
 *  unreadable file goes straight to m_BoostModel->read(); confirm how such
 *  failures are meant to be reported.
 */
104 template <
class TInputValue,
class TOutputValue>
107 cv::FileStorage fs(filename, cv::FileStorage::READ);
108 m_BoostModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
/** Check whether `file` looks like a stored OpenCV Boost model.
 *
 *  Reads text from the stream `ifs` line by line into `line`. A line matches
 *  when it contains CV_TYPE_NAME_ML_BOOSTING or the model's default name
 *  (m_BoostModel->getDefaultName()). The message "Could not read file" is
 *  printed to std::cerr, presumably when the file cannot be opened. The
 *  enclosing control flow and return statements are not visible here.
 */
111 template <
class TInputValue,
class TOutputValue>
119 std::cerr <<
"Could not read file " << file << std::endl;
126 std::getline(ifs, line);
// Accept either the legacy OpenCV type name or the cv::ml::Boost default node name.
129 if (line.find(
CV_TYPE_NAME_ML_BOOSTING) != std::string::npos || line.find(m_BoostModel->getDefaultName()) != std::string::npos)
/** NOTE(review): this template header is not followed by the member it should
 *  introduce; its signature and body are absent at this point of the file.
 *  Presumably this is the file-writability check; confirm against the class
 *  declaration.
 */
139 template <
class TInputValue,
class TOutputValue>
/** Print the object state to `os` using `indent`, delegating to
 *  Superclass::PrintSelf. */
145 template <
class TInputValue,
class TOutputValue>
149 Superclass::PrintSelf(os, indent);