21 #ifndef otbLibSVMMachineLearningModel_hxx
22 #define otbLibSVMMachineLearningModel_hxx
34 template <
class TInputValue,
class TOutputValue>
37 this->SetSVMType(C_SVC);
38 this->SetKernelType(LINEAR);
39 this->SetPolynomialKernelDegree(3);
40 this->SetKernelGamma(1.);
41 this->SetKernelCoef0(1.);
44 this->SetEpsilon(1e-3);
46 this->SetDoProbabilityEstimates(
false);
47 this->DoShrinking(
true);
48 this->SetCacheSize(40);
49 this->m_ParameterOptimization =
false;
50 this->m_IsRegressionSupported =
true;
51 this->SetCVFolders(5);
52 this->m_InitialCrossValidationAccuracy = 0.;
53 this->m_FinalCrossValidationAccuracy = 0.;
54 this->m_CoarseOptimizationNumberOfSteps = 5;
55 this->m_FineOptimizationNumberOfSteps = 5;
58 this->m_Parameters.nr_weight = 0;
59 this->m_Parameters.weight_label =
nullptr;
60 this->m_Parameters.weight =
nullptr;
62 this->m_Model =
nullptr;
64 this->m_Problem.l = 0;
65 this->m_Problem.y =
nullptr;
66 this->m_Problem.x =
nullptr;
72 template <
class TInputValue,
class TOutputValue>
76 this->DeleteProblem();
80 template <
class TInputValue,
class TOutputValue>
83 this->DeleteProblem();
91 this->ConsistencyCheck();
94 this->OptimizeParameters();
97 m_Model = svm_train(&m_Problem, &m_Parameters);
99 this->m_ConfidenceIndex = this->HasProbabilities();
102 template <
class TInputValue,
class TOutputValue>
110 int svm_type = svm_get_svm_type(m_Model);
114 struct svm_node* x =
new struct svm_node[input.Size() + 1];
117 for (
unsigned int i = 0; i < input.Size(); i++)
120 x[i].value = input[i];
124 x[input.Size()].index = -1;
125 x[input.Size()].value = 0;
126 if (proba !=
nullptr && !this->m_ProbaIndex)
127 itkExceptionMacro(
"Probability per class not available for this classifier !");
129 if (quality !=
nullptr)
131 if (!this->m_ConfidenceIndex)
133 itkExceptionMacro(
"Confidence index not available for this classifier !");
135 if (this->m_ConfidenceMode == CM_INDEX)
137 if (svm_type == C_SVC || svm_type == NU_SVC)
140 unsigned int nr_class = svm_get_nr_class(m_Model);
141 double* prob_estimates =
new double[nr_class];
143 target[0] =
static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates));
144 double maxProb = 0.0;
145 double secProb = 0.0;
146 for (
unsigned int i = 0; i < nr_class; ++i)
148 if (maxProb < prob_estimates[i])
151 maxProb = prob_estimates[i];
153 else if (secProb < prob_estimates[i])
155 secProb = prob_estimates[i];
160 delete[] prob_estimates;
168 (*quality) = svm_get_svr_probability(m_Model);
171 else if (this->m_ConfidenceMode == CM_PROBA)
173 target[0] =
static_cast<TargetValueType>(svm_predict_probability(m_Model, x, quality));
175 else if (this->m_ConfidenceMode == CM_HYPER)
177 target[0] =
static_cast<TargetValueType>(svm_predict_values(m_Model, x, quality));
184 if (svm_check_probability_model(m_Model))
186 unsigned int nr_class = svm_get_nr_class(m_Model);
187 double* prob_estimates =
new double[nr_class];
188 target[0] =
static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates));
189 delete[] prob_estimates;
203 template <
class TInputValue,
class TOutputValue>
206 if (svm_save_model(filename.c_str(), m_Model) != 0)
208 itkExceptionMacro(<<
"Problem while saving SVM model " << filename);
212 template <
class TInputValue,
class TOutputValue>
216 m_Model = svm_load_model(filename.c_str());
217 if (m_Model ==
nullptr)
219 itkExceptionMacro(<<
"Problem while loading SVM model " << filename);
221 m_Parameters = m_Model->param;
223 this->m_ConfidenceIndex = this->HasProbabilities();
226 template <
class TInputValue,
class TOutputValue>
235 std::cerr <<
"Could not read file " << file << std::endl;
241 std::getline(ifs, line);
244 if (line.find(
"svm_type") != std::string::npos)
253 template <
class TInputValue,
class TOutputValue>
259 template <
class TInputValue,
class TOutputValue>
263 Superclass::PrintSelf(os, indent);
266 template <
class TInputValue,
class TOutputValue>
269 bool modelHasProba =
static_cast<bool>(svm_check_probability_model(m_Model));
270 int type = svm_get_svm_type(m_Model);
271 int cmMode = this->m_ConfidenceMode;
273 if (type == EPSILON_SVR || type == NU_SVR)
275 ret = (modelHasProba && cmMode == CM_INDEX);
277 else if (type == C_SVC || type == NU_SVC)
279 ret = (modelHasProba && (cmMode == CM_INDEX || cmMode == CM_PROBA)) || (cmMode == CM_HYPER);
284 template <
class TInputValue,
class TOutputValue>
288 typename InputListSampleType::Pointer samples = this->GetInputListSample();
289 typename TargetListSampleType::Pointer target = this->GetTargetListSample();
290 int probl = samples->Size();
294 itkExceptionMacro(<<
"No samples, can not build SVM problem.");
299 long int elements = samples->GetMeasurementVectorSize();
303 m_Problem.y =
new double[probl];
304 m_Problem.x =
new struct svm_node*[probl];
305 for (
int i = 0; i < probl; ++i)
307 m_Problem.x[i] =
new struct svm_node[elements + 1];
311 typename InputListSampleType::ConstIterator sIt = samples->Begin();
312 typename TargetListSampleType::ConstIterator tIt = target->Begin();
315 while (sIt != samples->End() && tIt != target->End())
318 m_Problem.y[sampleIndex] = tIt.GetMeasurementVector()[0];
320 for (
int k = 0; k < elements; ++k)
322 m_Problem.x[sampleIndex][k].index = k + 1;
323 m_Problem.x[sampleIndex][k].value = sample[k];
326 m_Problem.x[sampleIndex][elements].index = -1;
327 m_Problem.x[sampleIndex][elements].value = 0;
335 if (this->GetKernelGamma() == 0)
337 this->SetKernelGamma(1.0 /
static_cast<double>(elements));
341 m_TmpTarget.resize(m_Problem.l);
344 template <
class TInputValue,
class TOutputValue>
347 if (this->GetSVMType() == ONE_CLASS && this->GetDoProbabilityEstimates())
349 otbMsgDebugMacro(<<
"Disabling SVM probability estimates for ONE_CLASS SVM type.");
350 this->SetDoProbabilityEstimates(
false);
353 const char* error_msg = svm_check_parameter(&m_Problem, &m_Parameters);
357 std::string err(error_msg);
358 itkExceptionMacro(
"SVM parameter check failed : " << err);
362 template <
class TInputValue,
class TOutputValue>
367 delete[] m_Problem.y;
368 m_Problem.y =
nullptr;
372 for (
int i = 0; i < m_Problem.l; ++i)
376 delete[] m_Problem.x[i];
379 delete[] m_Problem.x;
380 m_Problem.x =
nullptr;
385 template <
class TInputValue,
class TOutputValue>
390 svm_free_and_destroy_model(&m_Model);
395 template <
class TInputValue,
class TOutputValue>
399 switch (this->GetKernelType())
425 template <
class TInputValue,
class TOutputValue>
428 double accuracy = 0.0;
430 unsigned int length = m_Problem.l;
431 if (length == 0 || m_TmpTarget.size() < length)
435 svm_cross_validation(&m_Problem, &m_Parameters, m_CVFolders, &m_TmpTarget[0]);
438 double total_correct = 0.;
439 for (
unsigned int i = 0; i < length; ++i)
441 if (m_TmpTarget[i] == m_Problem.y[i])
446 accuracy = total_correct / length;
452 template <
class TInputValue,
class TOutputValue>
456 typename CrossValidationFunctionType::Pointer crossValidationFunction = CrossValidationFunctionType::New();
457 crossValidationFunction->SetModel(
this);
459 typename CrossValidationFunctionType::ParametersType initialParameters, coarseBestParameters, fineBestParameters;
461 unsigned int nbParams = this->GetNumberOfKernelParameters();
462 initialParameters.SetSize(nbParams);
463 initialParameters[0] = this->GetC();
465 initialParameters[1] = this->GetKernelGamma();
467 initialParameters[2] = this->GetKernelCoef0();
469 m_InitialCrossValidationAccuracy = crossValidationFunction->GetValue(initialParameters);
470 m_FinalCrossValidationAccuracy = m_InitialCrossValidationAccuracy;
472 otbMsgDebugMacro(<<
"Initial accuracy : " << m_InitialCrossValidationAccuracy <<
", Parameters Optimization" << m_ParameterOptimization);
474 if (m_ParameterOptimization)
479 coarseNbSteps.Fill(m_CoarseOptimizationNumberOfSteps);
481 coarseOptimizer->SetNumberOfSteps(coarseNbSteps);
482 coarseOptimizer->SetCostFunction(crossValidationFunction);
483 coarseOptimizer->SetInitialPosition(initialParameters);
484 coarseOptimizer->StartOptimization();
486 coarseBestParameters = coarseOptimizer->GetMaximumMetricValuePosition();
488 otbMsgDevMacro(<<
"Coarse minimum accuracy: " << coarseOptimizer->GetMinimumMetricValue() <<
" " << coarseOptimizer->GetMinimumMetricValuePosition());
489 otbMsgDevMacro(<<
"Coarse maximum accuracy: " << coarseOptimizer->GetMaximumMetricValue() <<
" " << coarseOptimizer->GetMaximumMetricValuePosition());
493 fineNbSteps.Fill(m_FineOptimizationNumberOfSteps);
495 double stepLength = 1. /
static_cast<double>(m_FineOptimizationNumberOfSteps);
497 fineOptimizer->SetNumberOfSteps(fineNbSteps);
498 fineOptimizer->SetStepLength(stepLength);
499 fineOptimizer->SetCostFunction(crossValidationFunction);
500 fineOptimizer->SetInitialPosition(coarseBestParameters);
501 fineOptimizer->StartOptimization();
503 otbMsgDevMacro(<<
"Fine minimum accuracy: " << fineOptimizer->GetMinimumMetricValue() <<
" " << fineOptimizer->GetMinimumMetricValuePosition());
504 otbMsgDevMacro(<<
"Fine maximum accuracy: " << fineOptimizer->GetMaximumMetricValue() <<
" " << fineOptimizer->GetMaximumMetricValuePosition());
506 fineBestParameters = fineOptimizer->GetMaximumMetricValuePosition();
508 m_FinalCrossValidationAccuracy = fineOptimizer->GetMaximumMetricValue();
510 this->SetC(fineBestParameters[0]);
512 this->SetKernelGamma(fineBestParameters[1]);
514 this->SetKernelCoef0(fineBestParameters[2]);