21 #ifndef otbSVMMachineLearningModel_hxx
22 #define otbSVMMachineLearningModel_hxx
// Constructor initializer-list fragment. The class name and constructor
// signature are not visible in this extract.
// Visible defaults:
//   - m_SVMModel: a freshly created OpenCV cv::ml::SVM instance
//   - m_SVMType: CvSVM::C_SVC
//   - m_KernelType: CvSVM::RBF
//   - m_TermCriteriaType: CV_TERMCRIT_ITER (stop on iteration count)
//   - m_Epsilon: FLT_EPSILON
//   - m_ParameterOptimization: false
// NOTE(review): CvSVM::C_SVC / CvSVM::RBF use the legacy OpenCV 2.x CvSVM
// names, while the model itself comes from the OpenCV 3 cv::ml::SVM API.
// Presumably a compatibility definition maps one onto the other; confirm.
// NOTE(review): initializers between these lines are elided in this extract,
// and each line carries a fused source line number (e.g. "35"). Verify
// against the complete file.
32 template <
class TInputValue,
class TOutputValue>
35 m_SVMModel(cv::ml::SVM::create()),
36 m_SVMType(
CvSVM::C_SVC),
37 m_KernelType(
CvSVM::RBF),
44 m_TermCriteriaType(CV_TERMCRIT_ITER),
46 m_Epsilon(FLT_EPSILON),
47 m_ParameterOptimization(false),
// Training fragment. The method signature and braces are not visible in this
// extract; presumably this is the Train() override (confirm).
// Visible steps:
//  1. Reject an SVM type that does not match m_RegressionMode. NU_SVR and
//     EPS_SVR are treated as regression types; every other type counts as
//     classification.
//  2. Convert the input and target list samples to cv::Mat.
//  3. Build the per-variable type vector passed to TrainData::create. In
//     classification mode, the last entry (index = measurement vector size)
//     is marked CV_VAR_CATEGORICAL.
//  4. Push the configured hyper-parameters into m_SVMModel.
//  5. Train with train(), or with trainAuto() when m_ParameterOptimization
//     is set.
//  6. Read the model's parameters back into the m_Output* members.
// Throws (itkGenericExceptionMacro) on an SVM type / mode mismatch.
// NOTE(review): lines are elided from this extract and carry fused source line
// numbers (e.g. "60"). Verify against the complete file.
60 template <
class TInputValue,
class TOutputValue>
// The comparison is bool vs m_RegressionMode: "is a regression SVM type" must
// agree with the requested mode.
64 if (
bool(m_SVMType == CvSVM::NU_SVR || m_SVMType == CvSVM::EPS_SVR) != this->m_RegressionMode)
66 itkGenericExceptionMacro(
67 "SVM type incompatible with chosen mode (classification or regression). "
68 "SVM types for classification are C_SVC, NU_SVC, ONE_CLASS. "
69 "SVM types for regression are NU_SVR, EPS_SVR");
74 otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
77 otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
// One type entry per input variable, plus one for the response.
// NOTE(review): cv::Mat(rows, cols, type) does not initialize its data, and
// only the last entry is assigned in the visible lines. Confirm that an elided
// line initializes the remaining entries (e.g. to CV_VAR_NUMERICAL).
79 cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
82 if (!this->m_RegressionMode)
83 var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) =
CV_VAR_CATEGORICAL;
85 m_SVMModel->setType(m_SVMType);
86 m_SVMModel->setKernel(m_KernelType);
87 m_SVMModel->setDegree(m_Degree);
88 m_SVMModel->setGamma(m_Gamma);
89 m_SVMModel->setCoef0(m_Coef0);
90 m_SVMModel->setC(m_C);
91 m_SVMModel->setNu(m_Nu);
92 m_SVMModel->setP(m_P);
93 m_SVMModel->setTermCriteria(cv::TermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon));
// train() uses the parameters set above as-is. trainAuto() searches the
// parameter grid itself (see OpenCV cv::ml::SVM::trainAuto).
95 if (!m_ParameterOptimization)
97 m_SVMModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
101 m_SVMModel->trainAuto(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
// Read the parameters back from the trained model. After trainAuto() these
// are the values it selected rather than the ones configured above.
104 m_OutputDegree = m_SVMModel->getDegree();
105 m_OutputGamma = m_SVMModel->getGamma();
106 m_OutputCoef0 = m_SVMModel->getCoef0();
107 m_OutputC = m_SVMModel->getC();
108 m_OutputNu = m_SVMModel->getNu();
109 m_OutputP = m_SVMModel->getP();
// Prediction fragment. The signature is not visible; presumably
// DoPredict(input, quality, proba) (confirm).
// Visible behaviour:
//   - Converts the input sample to cv::Mat.
//   - Stores m_SVMModel->predict() in target[0], cast to TOutputValue.
//   - When a non-null quality pointer is supplied, writes the model's raw,
//     unthresholded output (StatModel::RAW_OUTPUT) to it. This costs a second
//     predict() call.
//   - Throws (itkExceptionMacro) when a non-null proba pointer is passed while
//     m_ProbaIndex is false.
// NOTE(review): lines are elided from this extract and carry fused source line
// numbers (e.g. "120"). Verify against the complete file.
112 template <
class TInputValue,
class TOutputValue>
120 otb::SampleToMat<InputSampleType>(input, sample);
122 double result = m_SVMModel->predict(sample);
124 target[0] =
static_cast<TOutputValue
>(result);
126 if (quality !=
nullptr)
128 (*quality) = m_SVMModel->predict(sample, cv::noArray(), cv::ml::StatModel::RAW_OUTPUT);
130 if (proba !=
nullptr && !this->m_ProbaIndex)
131 itkExceptionMacro(
"Probability per class not available for this classifier !");
// Save fragment. The signature is not visible; presumably Save(filename, name).
// Opens `filename` for writing with cv::FileStorage and opens a map node. The
// node is named `name`, or the model's default name when `name` is empty. The
// SVM is then serialized into that node.
// NOTE(review): the matching "}" and the storage release are not visible in
// this extract; confirm they follow. Lines also carry fused source line
// numbers (e.g. "139").
136 template <
class TInputValue,
class TOutputValue>
139 cv::FileStorage fs(filename, cv::FileStorage::WRITE);
140 fs << (name.empty() ? m_SVMModel->getDefaultName() : cv::String(name)) <<
"{";
141 m_SVMModel->write(fs);
// Load fragment. The signature is not visible; presumably Load(filename, name).
// Opens `filename` read-only with cv::FileStorage and reads the SVM from one
// of two nodes:
//   - the node called `name`, or
//   - the first top-level node, when `name` is empty.
// NOTE(review): lines are elided from this extract and carry fused source line
// numbers (e.g. "149").
146 template <
class TInputValue,
class TOutputValue>
149 cv::FileStorage fs(filename, cv::FileStorage::READ);
150 m_SVMModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
// File-probing fragment. Presumably CanReadFile(file) (confirm).
// Visible behaviour:
//   - Prints "Could not read file <file>" to std::cerr on the failure path.
//   - Reads a line with std::getline; the surrounding loop, if any, is elided.
//   - Tests whether that line contains either the CV_TYPE_NAME_ML_SVM tag or
//     the SVM model's default name.
// NOTE(review): the file-open call, control flow and return values are not
// visible in this extract. Lines carry fused source line numbers (e.g. "161").
153 template <
class TInputValue,
class TOutputValue>
161 std::cerr <<
"Could not read file " << file << std::endl;
168 std::getline(ifs, line);
171 if (line.find(
CV_TYPE_NAME_ML_SVM) != std::string::npos || line.find(m_SVMModel->getDefaultName()) != std::string::npos)
// Template header of a member definition whose signature and body are elided
// from this extract. Nothing further can be documented from here.
// NOTE(review): the line carries a fused source line number ("180").
180 template <
class TInputValue,
class TOutputValue>
// Printing fragment, presumably PrintSelf(os, indent). It forwards to
// Superclass::PrintSelf(os, indent). No SVM-specific state is printed in the
// visible lines.
// NOTE(review): lines are elided from this extract and carry fused source line
// numbers (e.g. "190").
186 template <
class TInputValue,
class TOutputValue>
190 Superclass::PrintSelf(os, indent);