21 #ifndef otbTrainSVM_hxx
22 #define otbTrainSVM_hxx
/**
 * Declare the OpenCV SVM classifier parameters ("classifier.svm.*") and their
 * default values on this application.
 *
 * The SVM-formulation choices depend on m_RegressionFlag: epsilon-SVR / nu-SVR
 * in regression mode, C-SVC / nu-SVC / one-class SVM otherwise. The
 * "classifier.svm.p" parameter is only declared in regression mode.
 *
 * NOTE(review): in this listing no AddParameter call is visible for
 * "classifier.svm.m", "classifier.svm.k", "classifier.svm.c",
 * "classifier.svm.term", "classifier.svm.iter", "classifier.svm.eps" or
 * "classifier.svm.opt", nor the braces / `else` around the regression and
 * classification choices — confirm these are declared before the Set* calls
 * below.
 */
31 template <
class TInputValue,
class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitSVMParams()
// Register the SVM entry of the "classifier" choice; its description links to
// the OpenCV SVM documentation.
34 AddChoice(
"classifier.svm",
"SVM classifier (OpenCV)");
35 SetParameterDescription(
"classifier.svm",
"http://docs.opencv.org/modules/ml/doc/support_vector_machines.html");
// SVM formulation ("classifier.svm.m"): the available choices depend on
// whether the application trains a regressor or a classifier.
37 SetParameterDescription(
"classifier.svm.m",
"Type of SVM formulation.");
38 if (this->m_RegressionFlag)
// Regression: epsilon-SVR (default) or nu-SVR.
40 AddChoice(
"classifier.svm.m.epssvr",
"Epsilon Support Vector Regression");
41 AddChoice(
"classifier.svm.m.nusvr",
"Nu Support Vector Regression");
42 SetParameterString(
"classifier.svm.m",
"epssvr");
// Classification — presumably the non-regression branch (no `else` is visible
// here): C-SVC (default), nu-SVC or one-class SVM.
46 AddChoice(
"classifier.svm.m.csvc",
"C support vector classification");
47 AddChoice(
"classifier.svm.m.nusvc",
"Nu support vector classification");
48 AddChoice(
"classifier.svm.m.oneclass",
"Distribution estimation (One Class SVM)");
49 SetParameterString(
"classifier.svm.m",
"csvc");
// Kernel type ("classifier.svm.k"): linear (default), RBF, polynomial, sigmoid.
52 AddChoice(
"classifier.svm.k.linear",
"Linear");
54 AddChoice(
"classifier.svm.k.rbf",
"Gaussian radial basis function");
55 AddChoice(
"classifier.svm.k.poly",
"Polynomial");
56 AddChoice(
"classifier.svm.k.sigmoid",
"Sigmoid");
57 SetParameterString(
"classifier.svm.k",
"linear");
58 SetParameterDescription(
"classifier.svm.k",
"SVM Kernel Type.");
// Cost parameter C, default 1.0 (consistent with "1 by default" in its
// description).
60 SetParameterFloat(
"classifier.svm.c", 1.0);
61 SetParameterDescription(
"classifier.svm.c",
62 "SVM models have a cost parameter C (1 by default) to control the trade-off"
63 " between training errors and forcing rigid margins.");
// nu of the SVM optimization problem (labelled NU_SVC / ONE_CLASS), default 0.0.
// NOTE(review): OpenCV's CvSVM appears to reject nu outside (0, 1) for the
// nu-SVC / one-class / nu-SVR formulations, so this 0.0 default would likely
// fail if one of them is selected without overriding nu — confirm.
64 AddParameter(
ParameterType_Float,
"classifier.svm.nu",
"Parameter nu of a SVM optimization problem (NU_SVC / ONE_CLASS)");
65 SetParameterFloat(
"classifier.svm.nu", 0.0);
66 SetParameterDescription(
"classifier.svm.nu",
"Parameter nu of a SVM optimization problem.");
// epsilon (p) of EPS_SVR, default 1.0; only declared in regression mode.
67 if (this->m_RegressionFlag)
69 AddParameter(
ParameterType_Float,
"classifier.svm.p",
"Parameter epsilon of a SVM optimization problem (EPS_SVR)");
70 SetParameterFloat(
"classifier.svm.p", 1.0);
71 SetParameterDescription(
"classifier.svm.p",
"Parameter epsilon of a SVM optimization problem (EPS_SVR).");
// Termination criterion of the iterative solver: iteration count, accuracy,
// or whichever is reached first.
74 SetParameterDescription(
"classifier.svm.term",
"Termination criteria for iterative algorithm");
75 AddChoice(
"classifier.svm.term.iter",
"Stops when maximum iteration is reached.");
76 AddChoice(
"classifier.svm.term.eps",
"Stops when accuracy is lower than epsilon.");
77 AddChoice(
"classifier.svm.term.all",
"Stops when either iteration or epsilon criteria is true");
// Maximum number of iterations, default 1000.
// NOTE(review): the default is set with SetParameterFloat, but TrainSVM reads
// this parameter with GetParameterInt("classifier.svm.iter"). If the parameter
// is declared as an int parameter, the float setter may not apply the
// default — confirm the declared type and use the matching setter.
80 SetParameterFloat(
"classifier.svm.iter", 1000);
81 SetParameterDescription(
"classifier.svm.iter",
"Maximum number of iterations (corresponds to the termination criteria 'iter').");
// Accuracy threshold for the 'eps' criterion, default FLT_EPSILON.
84 SetParameterFloat(
"classifier.svm.eps", FLT_EPSILON);
85 SetParameterDescription(
"classifier.svm.eps",
"Epsilon accuracy (corresponds to the termination criteria 'eps').");
// Kernel coefficients: coef0 (POLY / SIGMOID, default 0.0),
// gamma (POLY / RBF / SIGMOID, default 1.0) and degree (POLY, default 1.0).
87 AddParameter(
ParameterType_Float,
"classifier.svm.coef0",
"Parameter coef0 of a kernel function (POLY / SIGMOID)");
88 SetParameterFloat(
"classifier.svm.coef0", 0.0);
89 SetParameterDescription(
"classifier.svm.coef0",
"Parameter coef0 of a kernel function (POLY / SIGMOID).");
90 AddParameter(
ParameterType_Float,
"classifier.svm.gamma",
"Parameter gamma of a kernel function (POLY / RBF / SIGMOID)");
91 SetParameterFloat(
"classifier.svm.gamma", 1.0);
92 SetParameterDescription(
"classifier.svm.gamma",
"Parameter gamma of a kernel function (POLY / RBF / SIGMOID).");
93 AddParameter(
ParameterType_Float,
"classifier.svm.degree",
"Parameter degree of a kernel function (POLY)");
94 SetParameterFloat(
"classifier.svm.degree", 1.0);
95 SetParameterDescription(
"classifier.svm.degree",
"Parameter degree of a kernel function (POLY).");
// Parameter-optimization flag ("classifier.svm.opt"); the user-facing text
// below explains the 10-fold cross-validation behaviour.
97 SetParameterDescription(
"classifier.svm.opt",
98 "SVM parameters optimization flag.\n"
99 "-If set to True, then the optimal SVM parameters will be estimated. "
100 "Parameters are considered optimal by OpenCV when the cross-validation estimate of "
101 "the test set error is minimal. Finally, the SVM training process is computed "
102 "10 times with these optimal parameters over subsets corresponding to 1/10th of "
103 "the training samples using the k-fold cross-validation (with k = 10).\n-If set "
104 "to False, the SVM classification process will be computed once with the "
105 "currently set input SVM parameters over the training samples.\n-Thus, even "
106 "with identical input SVM parameters and a similar random seed, the output "
107 "SVM models will be different according to the method used (optimized or not) "
108 "because the samples are not identically processed within OpenCV.");
/**
 * Train an SVM model on the given samples and save it to modelPath.
 *
 * Reads the "classifier.svm.*" parameters declared by InitSVMParams and
 * configures an SVMType instance with them. It then trains the model, saves it,
 * and writes the model's output C, nu, p (regression only), coef0, gamma and
 * degree back into the application parameters.
 *
 * @param trainingListSample        input feature samples.
 * @param trainingLabeledListSample target samples (presumably class labels, or
 *                                  regression targets when m_RegressionFlag is set).
 * @param modelPath                 path passed to SVMClassifier->Save().
 *
 * NOTE(review): the `case` labels, `break`s and braces of the switches below
 * are not visible in this listing. The mappings described here assume choice
 * indices follow the AddChoice order used in InitSVMParams — confirm.
 */
111 template <
class TInputValue,
class TOutputValue>
112 void LearningApplicationBase<TInputValue, TOutputValue>::TrainSVM(
typename ListSampleType::Pointer trainingListSample,
113 typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
// Create the model, select regression vs classification, and attach the
// training samples and their targets.
116 typename SVMType::Pointer SVMClassifier = SVMType::New();
117 SVMClassifier->SetRegressionMode(this->m_RegressionFlag);
118 SVMClassifier->SetInputListSample(trainingListSample);
119 SVMClassifier->SetTargetListSample(trainingLabeledListSample);
// Kernel choice -> CvSVM kernel constant (linear, rbf, poly, sigmoid); the
// trailing LINEAR assignment looks like the default/fallback branch.
// NOTE(review): every branch prints the chosen constant to std::cout. This looks
// like leftover debugging output; consider the application logger
// (otbAppLogDEBUG) or removing it.
120 switch (GetParameterInt(
"classifier.svm.k"))
123 SVMClassifier->SetKernelType(CvSVM::LINEAR);
124 std::cout <<
"CvSVM::LINEAR = " << CvSVM::LINEAR << std::endl;
127 SVMClassifier->SetKernelType(CvSVM::RBF);
128 std::cout <<
"CvSVM::RBF = " << CvSVM::RBF << std::endl;
131 SVMClassifier->SetKernelType(CvSVM::POLY);
132 std::cout <<
"CvSVM::POLY = " << CvSVM::POLY << std::endl;
135 SVMClassifier->SetKernelType(CvSVM::SIGMOID);
136 std::cout <<
"CvSVM::SIGMOID = " << CvSVM::SIGMOID << std::endl;
139 SVMClassifier->SetKernelType(CvSVM::LINEAR);
140 std::cout <<
"CvSVM::LINEAR = " << CvSVM::LINEAR << std::endl;
// Formulation choice -> CvSVM type. The same "classifier.svm.m" index means
// EPS_SVR / NU_SVR in regression mode and C_SVC / NU_SVC / ONE_CLASS otherwise;
// EPS_SVR and C_SVC also appear as the (presumed) fallbacks.
143 if (this->m_RegressionFlag)
145 switch (GetParameterInt(
"classifier.svm.m"))
148 SVMClassifier->SetSVMType(CvSVM::EPS_SVR);
149 std::cout <<
"CvSVM::EPS_SVR = " << CvSVM::EPS_SVR << std::endl;
152 SVMClassifier->SetSVMType(CvSVM::NU_SVR);
153 std::cout <<
"CvSVM::NU_SVR = " << CvSVM::NU_SVR << std::endl;
156 SVMClassifier->SetSVMType(CvSVM::EPS_SVR);
157 std::cout <<
"CvSVM::EPS_SVR = " << CvSVM::EPS_SVR << std::endl;
163 switch (GetParameterInt(
"classifier.svm.m"))
166 SVMClassifier->SetSVMType(CvSVM::C_SVC);
167 std::cout <<
"CvSVM::C_SVC = " << CvSVM::C_SVC << std::endl;
170 SVMClassifier->SetSVMType(CvSVM::NU_SVC);
171 std::cout <<
"CvSVM::NU_SVC = " << CvSVM::NU_SVC << std::endl;
174 SVMClassifier->SetSVMType(CvSVM::ONE_CLASS);
175 std::cout <<
"CvSVM::ONE_CLASS = " << CvSVM::ONE_CLASS << std::endl;
178 SVMClassifier->SetSVMType(CvSVM::C_SVC);
179 std::cout <<
"CvSVM::C_SVC = " << CvSVM::C_SVC << std::endl;
// Scalar hyper-parameters. "classifier.svm.p" is only declared in regression
// mode, so it is only read then.
183 SVMClassifier->SetC(GetParameterFloat(
"classifier.svm.c"));
184 SVMClassifier->SetNu(GetParameterFloat(
"classifier.svm.nu"));
185 if (this->m_RegressionFlag)
187 SVMClassifier->SetP(GetParameterFloat(
"classifier.svm.p"));
// Termination criterion: iterations, accuracy, or both (ITER + EPS); ITER is
// also used as the (presumed) fallback.
188 switch (GetParameterInt(
"classifier.svm.term"))
191 SVMClassifier->SetTermCriteriaType(CV_TERMCRIT_ITER);
194 SVMClassifier->SetTermCriteriaType(CV_TERMCRIT_EPS);
197 SVMClassifier->SetTermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
200 SVMClassifier->SetTermCriteriaType(CV_TERMCRIT_ITER);
// Stopping limits.
// NOTE(review): "classifier.svm.iter" is read as an int here, but InitSVMParams
// sets its default with SetParameterFloat — confirm the declared type.
203 SVMClassifier->SetMaxIter(GetParameterInt(
"classifier.svm.iter"));
204 SVMClassifier->SetEpsilon(GetParameterFloat(
"classifier.svm.eps"));
// Kernel coefficients and the optional parameter-optimization flag.
206 SVMClassifier->SetCoef0(GetParameterFloat(
"classifier.svm.coef0"));
207 SVMClassifier->SetGamma(GetParameterFloat(
"classifier.svm.gamma"));
208 SVMClassifier->SetDegree(GetParameterFloat(
"classifier.svm.degree"));
209 SVMClassifier->SetParameterOptimization(GetParameterInt(
"classifier.svm.opt"));
// Train, then persist the model to modelPath.
210 SVMClassifier->Train();
211 SVMClassifier->Save(modelPath);
// Write the values reported by the trained model back into the application
// parameters (narrowed to float). Presumably this exposes the parameters chosen
// when optimization is enabled.
214 SetParameterFloat(
"classifier.svm.c",
static_cast<float>(SVMClassifier->GetOutputC()));
215 SetParameterFloat(
"classifier.svm.nu",
static_cast<float>(SVMClassifier->GetOutputNu()));
216 if (this->m_RegressionFlag)
218 SetParameterFloat(
"classifier.svm.p",
static_cast<float>(SVMClassifier->GetOutputP()));
220 SetParameterFloat(
"classifier.svm.coef0",
static_cast<float>(SVMClassifier->GetOutputCoef0()));
221 SetParameterFloat(
"classifier.svm.gamma",
static_cast<float>(SVMClassifier->GetOutputGamma()));
222 SetParameterFloat(
"classifier.svm.degree",
static_cast<float>(SVMClassifier->GetOutputDegree()));