21 #ifndef otbTrainLibSVM_hxx
22 #define otbTrainLibSVM_hxx
/**
 * Registers the "classifier.libsvm" choice of the learning application and
 * its sub-parameters: kernel type ("classifier.libsvm.k"), SVM formulation
 * ("classifier.libsvm.m"), C, gamma, coef0, degree, nu, the "opt" and "prob"
 * flags and, in regression mode only, epsilon ("classifier.libsvm.eps").
 *
 * Defaults set here: kernel "linear", C = 1.0, gamma = 1.0 (minimum 0.0),
 * coef0 = 0.0, degree = 3 (minimum 1), nu = 0.5, eps = 1e-3, and formulation
 * "epssvr" in regression mode or "csvc" for classification.
 */
31 template <
class TInputValue,
class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitLibSVMParams()
// Top-level LibSVM entry in the classifier choice list.
34 AddChoice(
"classifier.libsvm",
"LibSVM classifier");
35 SetParameterDescription(
"classifier.libsvm",
"This group of parameters allows setting SVM classifier parameters.");
// Kernel choices. Their registration order (linear, rbf, poly, sigmoid)
// presumably defines the integer index TrainLibSVM switches on -- keep the
// two in sync.
37 AddChoice(
"classifier.libsvm.k.linear",
"Linear");
38 SetParameterDescription(
"classifier.libsvm.k.linear",
"Linear Kernel, no mapping is done, this is the fastest option.");
40 AddChoice(
"classifier.libsvm.k.rbf",
"Gaussian radial basis function");
41 SetParameterDescription(
"classifier.libsvm.k.rbf",
42 "This kernel is a good choice in most of the case. It is "
43 "an exponential function of the euclidean distance between "
46 AddChoice(
"classifier.libsvm.k.poly",
"Polynomial");
47 SetParameterDescription(
"classifier.libsvm.k.poly",
"Polynomial Kernel, the mapping is a polynomial function.");
49 AddChoice(
"classifier.libsvm.k.sigmoid",
"Sigmoid");
// NOTE(review): "tangente" in the description below is a typo for "tangent"
// (user-facing text; fix it in a behavior-changing edit).
50 SetParameterDescription(
"classifier.libsvm.k.sigmoid",
"The kernel is a hyperbolic tangente function of the vectors.");
// The linear kernel is the default.
52 SetParameterString(
"classifier.libsvm.k",
"linear");
53 SetParameterDescription(
"classifier.libsvm.k",
"SVM Kernel Type.");
// SVM formulation ("classifier.libsvm.m"): the offered choices depend on
// whether the application runs in regression mode.
55 SetParameterDescription(
"classifier.libsvm.m",
"Type of SVM formulation.");
56 if (this->m_RegressionFlag)
// Regression formulations; epsilon-SVR is made the default.
58 AddChoice(
"classifier.libsvm.m.epssvr",
"Epsilon Support Vector Regression");
59 SetParameterDescription(
"classifier.libsvm.m.epssvr",
60 "The distance between feature vectors from the training set and the "
61 "fitting hyper-plane must be less than Epsilon. For outliers the penalty "
62 "multiplier C is used ");
64 AddChoice(
"classifier.libsvm.m.nusvr",
"Nu Support Vector Regression");
65 SetParameterString(
"classifier.libsvm.m",
"epssvr");
66 SetParameterDescription(
"classifier.libsvm.m.nusvr",
67 "Same as the epsilon regression except that this time the bounded "
68 "parameter nu is used instead of epsilon");
// Classification formulations; C-SVC is made the default. Presumably this is
// the branch taken when m_RegressionFlag is false -- TODO confirm.
72 AddChoice(
"classifier.libsvm.m.csvc",
"C support vector classification");
73 SetParameterDescription(
"classifier.libsvm.m.csvc",
74 "This formulation allows imperfect separation of classes. The penalty "
75 "is set through the cost parameter C.");
77 AddChoice(
"classifier.libsvm.m.nusvc",
"Nu support vector classification");
78 SetParameterDescription(
"classifier.libsvm.m.nusvc",
79 "This formulation allows imperfect separation of classes. The penalty "
80 "is set through the cost parameter Nu. As compared to C, Nu is harder "
81 "to optimize, and may not be as fast.");
83 AddChoice(
"classifier.libsvm.m.oneclass",
"Distribution estimation (One Class SVM)");
84 SetParameterDescription(
"classifier.libsvm.m.oneclass",
85 "All the training data are from the same class, SVM builds a boundary "
86 "that separates the class from the rest of the feature space.");
87 SetParameterString(
"classifier.libsvm.m",
"csvc");
// Cost parameter C, default 1.0.
91 SetParameterFloat(
"classifier.libsvm.c", 1.0);
92 SetParameterDescription(
"classifier.libsvm.c",
93 "SVM models have a cost parameter C (1 by default) to control the "
94 "trade-off between training errors and forcing rigid margins.");
// Kernel gamma, default 1.0, with a minimum of 0.0.
97 SetParameterFloat(
"classifier.libsvm.gamma", 1.0);
98 SetMinimumParameterFloatValue(
"classifier.libsvm.gamma", 0.0);
99 SetParameterDescription(
"classifier.libsvm.gamma",
"Set gamma parameter in poly/rbf/sigmoid kernel function");
// Kernel coef0, default 0.0.
102 SetParameterFloat(
"classifier.libsvm.coef0", 0.0);
103 SetParameterDescription(
"classifier.libsvm.coef0",
"Set coef0 parameter in poly/sigmoid kernel function");
// Polynomial degree, default 3, with a minimum of 1.
106 SetParameterInt(
"classifier.libsvm.degree", 3);
107 SetMinimumParameterIntValue(
"classifier.libsvm.degree", 1);
108 SetParameterDescription(
"classifier.libsvm.degree",
"Set polynomial degree in poly kernel function");
// Nu, default 0.5.
// NOTE(review): the description states a 0..1 range, but no bound appears to
// be enforced for nu (gamma and degree do set minimums) -- consider adding one.
111 SetParameterFloat(
"classifier.libsvm.nu", 0.5);
112 SetParameterDescription(
"classifier.libsvm.nu",
113 "Cost parameter Nu, in the range 0..1, the larger the value, "
114 "the smoother the decision.");
118 SetParameterDescription(
"classifier.libsvm.opt",
"SVM parameters optimization flag.");
121 SetParameterDescription(
"classifier.libsvm.prob",
"Probability estimation flag.");
// Epsilon is registered only in regression mode, default 1e-3.
// NOTE(review): the literal ending in "For outliers" below has no trailing
// space, so the concatenated description reads "For outliersthe penalty ...".
123 if (this->m_RegressionFlag)
126 SetParameterFloat(
"classifier.libsvm.eps", 1e-3);
127 SetParameterDescription(
"classifier.libsvm.eps",
128 "The distance between feature vectors from the training set and "
129 "the fitting hyper-plane must be less than Epsilon. For outliers"
130 "the penalty multiplier is set by C.");
/**
 * Trains a LibSVM model on the given samples using the current
 * "classifier.libsvm.*" parameter values, then saves it to modelPath.
 *
 * @param trainingListSample        input samples, passed to SetInputListSample()
 * @param trainingLabeledListSample targets, passed to SetTargetListSample()
 *                                  (labels, or regression targets when
 *                                  m_RegressionFlag is set)
 * @param modelPath                 destination path passed to Save()
 */
134 template <
class TInputValue,
class TOutputValue>
135 void LearningApplicationBase<TInputValue, TOutputValue>::TrainLibSVM(
typename ListSampleType::Pointer trainingListSample,
136 typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
139 typename LibSVMType::Pointer libSVMClassifier = LibSVMType::New();
140 libSVMClassifier->SetRegressionMode(this->m_RegressionFlag);
141 libSVMClassifier->SetInputListSample(trainingListSample);
142 libSVMClassifier->SetTargetListSample(trainingLabeledListSample);
// Options forwarded regardless of the kernel and formulation chosen.
145 libSVMClassifier->SetParameterOptimization(GetParameterInt(
"classifier.libsvm.opt"));
146 libSVMClassifier->SetDoProbabilityEstimates(GetParameterInt(
"classifier.libsvm.prob"));
147 libSVMClassifier->SetNu(GetParameterFloat(
"classifier.libsvm.nu"));
148 libSVMClassifier->SetC(GetParameterFloat(
"classifier.libsvm.c"));
// Map the kernel choice index to a LibSVM kernel type. Each kernel forwards
// only the hyper-parameters it uses: gamma for RBF/POLY/SIGMOID, coef0 for
// POLY/SIGMOID, degree for POLY. Presumably the cases follow the registration
// order (linear, rbf, poly, sigmoid) and the final LINEAR assignment is the
// fallback for any other index -- TODO confirm against the case labels.
150 switch (GetParameterInt(
"classifier.libsvm.k"))
153 libSVMClassifier->SetKernelType(LINEAR);
156 libSVMClassifier->SetKernelType(RBF);
157 libSVMClassifier->SetKernelGamma(GetParameterFloat(
"classifier.libsvm.gamma"));
160 libSVMClassifier->SetKernelType(POLY);
161 libSVMClassifier->SetKernelGamma(GetParameterFloat(
"classifier.libsvm.gamma"));
162 libSVMClassifier->SetKernelCoef0(GetParameterFloat(
"classifier.libsvm.coef0"));
163 libSVMClassifier->SetPolynomialKernelDegree(GetParameterInt(
"classifier.libsvm.degree"));
166 libSVMClassifier->SetKernelType(SIGMOID);
167 libSVMClassifier->SetKernelGamma(GetParameterFloat(
"classifier.libsvm.gamma"));
168 libSVMClassifier->SetKernelCoef0(GetParameterFloat(
"classifier.libsvm.coef0"));
171 libSVMClassifier->SetKernelType(LINEAR);
// Map the formulation choice index. Regression mode offers EPSILON_SVR and
// NU_SVR, with EPSILON_SVR as the last (presumably default) branch.
174 if (this->m_RegressionFlag)
176 switch (GetParameterInt(
"classifier.libsvm.m"))
179 libSVMClassifier->SetSVMType(EPSILON_SVR);
182 libSVMClassifier->SetSVMType(NU_SVR);
185 libSVMClassifier->SetSVMType(EPSILON_SVR);
// "classifier.libsvm.eps" is only registered in regression mode, so this read
// presumably stays inside the regression branch.
188 libSVMClassifier->SetEpsilon(GetParameterFloat(
"classifier.libsvm.eps"));
// Classification offers C_SVC, NU_SVC and ONE_CLASS, with C_SVC as the last
// (presumably default) branch.
192 switch (GetParameterInt(
"classifier.libsvm.m"))
195 libSVMClassifier->SetSVMType(C_SVC);
198 libSVMClassifier->SetSVMType(NU_SVC);
201 libSVMClassifier->SetSVMType(ONE_CLASS);
204 libSVMClassifier->SetSVMType(C_SVC);
// Fit the configured model, then write it to modelPath.
210 libSVMClassifier->Train();
211 libSVMClassifier->Save(modelPath);