18 #ifndef __otbSVMModel_txx
19 #define __otbSVMModel_txx
// NOTE(review): only sampled lines of this definition are visible (no
// signature or braces). The statements set default SVM training parameters,
// presumably in the SVMModel constructor -- confirm against the full file.
27 template <
class TValue,
class TLabel>
// Default formulation: C_SVC with a LINEAR kernel; polynomial degree, gamma
// and coef0 defaults are set as well.
// NOTE(review): the problem builder only derives gamma automatically when
// gamma == 0, so with the default of 1. below that path is skipped unless a
// caller resets gamma to 0 (as far as the visible code shows).
31 this->SetSVMType(
C_SVC);
32 this->SetKernelType(
LINEAR);
33 this->SetPolynomialKernelDegree(3);
34 this->SetKernelGamma(1.);
35 this->SetKernelCoef0(1.);
36 this->SetKernelFunctor(
NULL);
// Solver defaults: cache size and epsilon.
38 this->SetCacheSize(40);
40 this->SetEpsilon(1e-3);
// Shrinking enabled, probability estimates disabled by default.
42 this->DoShrinking(
true);
43 this->DoProbabilityEstimates(
false);
// No generic/composed kernel and no per-class weights until set explicitly.
45 m_Parameters.kernel_generic =
NULL;
46 m_Parameters.kernel_composed =
NULL;
47 m_Parameters.nr_weight = 0;
48 m_Parameters.weight_label =
NULL;
49 m_Parameters.weight =
NULL;
// NOTE(review): fragment -- only the template header and one statement are
// visible. It releases the training-problem buffers via DeleteProblem();
// presumably this is the destructor (confirm against the full file).
56 template <
class TValue,
class TLabel>
60 this->DeleteProblem();
// NOTE(review): fragment -- signature not visible. Puts m_Model into an empty
// state (zero classes; null sv_coef/label/probA/probB arrays) and marks both
// the model and the training problem as stale.
62 template <
class TValue,
class TLabel>
71 m_Model->nr_class = 0;
73 m_Model->sv_coef =
NULL;
75 m_Model->label =
NULL;
76 m_Model->probA =
NULL;
77 m_Model->probB =
NULL;
// The evaluation paths throw while this flag is false.
80 m_ModelUpToDate =
false;
// Forces the next problem build to regenerate m_Problem from m_Samples.
89 m_ProblemUpToDate =
false;
// NOTE(review): fragment -- only one statement is visible; it discards the
// current training problem via DeleteProblem() (presumably part of a reset
// routine -- confirm).
92 template <
class TValue,
class TLabel>
96 this->DeleteProblem();
// NOTE(review): only the template header of this definition survives in this
// view; its signature and body are not visible.
106 template <
class TValue,
class TLabel>
// Frees the libsvm problem buffers: the label array y, each of the l sample
// rows in x, then the row-pointer array x itself, and marks the problem stale.
// NOTE(review): the lines that null y/x and reset l are not visible here --
// confirm they exist, otherwise a second call would delete the same memory twice.
114 template <
class TValue,
class TLabel>
121 delete[] m_Problem.y;
127 for (
int i = 0; i < m_Problem.l; ++i)
131 delete[] m_Problem.x[i];
134 delete[] m_Problem.x;
138 m_ProblemUpToDate =
false;
// Appends one sample to m_Samples and marks the training problem stale, so it
// is rebuilt from the samples the next time the problem is built.
141 template <
class TValue,
class TLabel>
146 m_Samples.push_back(newSample);
147 m_ProblemUpToDate =
false;
// NOTE(review): fragment -- only the staleness-flag update is visible
// (presumably the sample list is cleared on a hidden line -- confirm).
150 template <
class TValue,
class TLabel>
155 m_ProblemUpToDate =
false;
// NOTE(review): fragment -- only the staleness-flag update is visible
// (presumably the sample list is replaced on a hidden line -- confirm).
158 template <
class TValue,
class TLabel>
163 m_ProblemUpToDate =
false;
// Builds the libsvm training problem from m_Samples. m_Problem.y receives one
// label per sample. m_Problem.x receives one svm_node row per sample; each row
// is pre-filled with index -1 entries, so it always ends in a terminator.
// The build is skipped when m_ProblemUpToDate is set, and an exception is
// thrown when there are no samples (the surrounding control flow is not
// visible in this view).
// When gamma is 0, the kernel is neither COMPOSED nor GENERIC, and at least
// one feature was seen, gamma is set to 1 / (largest feature count stored).
166 template <
class TValue,
class TLabel>
171 if (m_ProblemUpToDate)
177 int probl = m_Samples.size();
181 itkExceptionMacro(<<
"No samples, can not build SVM problem.");
// Row width: feature count of the first sample plus one terminator slot.
186 long int elements = m_Samples[0].first.size() + 1;
189 this->DeleteProblem();
193 m_Problem.y =
new double[probl];
194 m_Problem.x =
new struct svm_node*[probl];
196 for (
int i = 0; i < probl; ++i)
200 m_Problem.x[i] =
new struct svm_node[elements];
203 for (
unsigned int j = 0; j < static_cast<unsigned int>(elements); ++j)
205 m_Problem.x[i][j].
index = -1;
206 m_Problem.x[i][j].value = 0;
211 typename SamplesVectorType::const_iterator sIt = m_Samples.begin();
213 int maxElementIndex = 0;
215 while (sIt != m_Samples.end())
223 m_Problem.y[sampleIndex] = label;
225 int elementIndex = 0;
// Stop one slot early so the trailing index -1 terminator is never overwritten
// when a later sample has more features than the first one (extra features are
// ignored). libsvm scans each row until it meets index -1, so losing the
// terminator makes it read past the end of the row.
228 for (
typename MeasurementType::const_iterator eIt = measure.begin();
229 eIt != measure.end() && elementIndex < elements - 1; ++eIt, ++elementIndex)
231 m_Problem.x[sampleIndex][elementIndex].index = elementIndex + 1;
232 m_Problem.x[sampleIndex][elementIndex].value = (*eIt);
236 if (elementIndex > maxElementIndex)
238 maxElementIndex = elementIndex;
// Derive gamma only when at least one feature was stored: with
// maxElementIndex == 0 the division would produce an infinite gamma.
246 if (this->GetKernelGamma() == 0
247 && this->GetParameters().kernel_type !=
COMPOSED
248 && this->GetParameters().kernel_type !=
GENERIC
&& maxElementIndex > 0) this->SetKernelGamma(1.0 / static_cast<double>(maxElementIndex));
251 m_ProblemUpToDate =
true;
// Estimates accuracy on the training problem. It (re)builds the problem and
// checks the parameters, then allocates target[] with one slot per sample;
// filling target[] with predictions is not visible here. The accuracy is the
// fraction of samples whose target equals the label in m_Problem.y.
// NOTE(review): the matching delete[] of target is not visible here -- confirm
// it is released on every path.
254 template <
class TValue,
class TLabel>
259 this->BuildProblem();
262 this->ConsistencyCheck();
265 int length = m_Problem.l;
268 double *target =
new double[length];
275 double total_correct = 0.;
277 for (i = 0; i < length; ++i)
279 if (target[i] == m_Problem.y[i])
284 double accuracy = total_correct / length;
// Fixes up parameter combinations before training. Probability estimates are
// switched off (with a debug message) for the ONE_CLASS SVM type. Other checks
// may exist on lines not visible here.
293 template <
class TValue,
class TLabel>
297 if (m_Parameters.svm_type ==
ONE_CLASS && this->GetDoProbabilityEstimates())
299 otbMsgDebugMacro(<<
"Disabling SVM probability estimates for ONE_CLASS SVM type.");
300 this->DoProbabilityEstimates(
false);
// Trains the model. It builds the problem from the samples, checks the
// parameters, runs svm_train on m_Problem with m_Parameters, and marks the
// resulting m_Model as up to date.
// NOTE(review): in stock libsvm the trained model points into the problem's
// rows, so m_Problem must not be freed while this model is in use. Also confirm
// that any previous m_Model is released before it is overwritten (not visible).
311 template <
class TValue,
class TLabel>
322 this->BuildProblem();
325 this->ConsistencyCheck();
328 m_Model =
svm_train(&m_Problem, &m_Parameters);
331 m_ModelUpToDate =
true;
// Predicts a label for one measurement with the trained model; it throws if
// the model is not up to date. The measurement is copied into an svm_node
// array x with 1-based indices and a final index -1 terminator. Probability
// estimates are only used for C_SVC / NU_SVC models, and the prob_estimates
// buffer allocated for them is freed at the end.
// NOTE(review): the condition that clears predict_probability is not visible
// here.
334 template <
class TValue,
class TLabel>
339 if (!m_ModelUpToDate)
341 itkExceptionMacro(<<
"Model is not up-to-date, can not predict label");
349 predict_probability = 0;
357 double *prob_estimates =
NULL;
360 if (predict_probability)
366 "Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma), sigma="
371 prob_estimates =
new double[nr_class];
382 for (
typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
384 x[valueIndex].
index = valueIndex + 1;
385 x[valueIndex].
value = (*mIt);
389 x[measure.size()].
index = -1;
390 x[measure.size()].
value = 0;
394 if (predict_probability && (svm_type ==
C_SVC || svm_type ==
NU_SVC))
408 delete[] prob_estimates;
// Requires an up-to-date model and throws otherwise; the error text reuses the
// "predict label" wording. The measurement is copied into an svm_node array x
// with 1-based indices and a final index -1 terminator. The decision-value
// computation itself is not visible in this view.
414 template <
class TValue,
class TLabel>
419 if (!m_ModelUpToDate)
421 itkExceptionMacro(<<
"Model is not up-to-date, can not predict label");
431 for (
typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
433 x[valueIndex].
index = valueIndex + 1;
434 x[valueIndex].
value = (*mIt);
438 x[measure.size()].
index = -1;
439 x[measure.size()].
value = 0;
// Returns per-class probability estimates ordered by ascending class label.
// It throws if the model is not up to date or does not support probability
// estimates. The model's label array is copied and sorted; the estimate for
// model class i (dec_values[i]) is then stored at the position of labels[i] in
// the sorted list.
// NOTE(review): the delete[] of dec_values is not visible here -- confirm it is
// released before returning.
453 template <
class TValue,
class TLabel>
458 if (!m_ModelUpToDate)
460 itkExceptionMacro(<<
"Model is not up-to-date, can not predict probabilities");
466 "Model does not support probability estimates", ITK_LOCATION);
479 for (
typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
481 x[valueIndex].
index = valueIndex + 1;
482 x[valueIndex].
value = (*mIt);
486 x[measure.size()].
index = -1;
487 x[measure.size()].
value = 0;
489 double* dec_values =
new double[nr_class];
493 int* labels = m_Model->label;
494 std::vector<int> orderedLabels(nr_class);
495 std::copy(labels, labels + nr_class, orderedLabels.begin());
496 std::sort(orderedLabels.begin(), orderedLabels.end());
499 for (
int i = 0; i < nr_class; ++i)
502 std::vector<int>::iterator it = std::find(orderedLabels.begin(), orderedLabels.end(), labels[i]);
503 probabilities[it - orderedLabels.begin()] = dec_values[i];
510 return probabilities;
// NOTE(review): fragment -- only the final flag update is visible. After the
// model is installed (hidden lines), it is marked up to date.
513 template <
class TValue,
class TLabel>
519 m_ModelUpToDate =
true;
// Saves the model to model_file_name and throws with the file name on
// failure. The save call and the failure test are not visible in this view.
522 template <
class TValue,
class TLabel>
528 itkExceptionMacro(<<
"Problem while saving SVM model "
529 << std::string(model_file_name));
// Loads a model from model_file_name, passing the current generic kernel, and
// throws with the file name if loading fails. It then copies the loaded
// model's parameters into m_Parameters and marks the model up to date.
// NOTE(review): confirm any previously held m_Model is released before it is
// overwritten (not visible here).
533 template <
class TValue,
class TLabel>
538 m_Model =
svm_load_model(model_file_name, m_Parameters.kernel_generic);
541 itkExceptionMacro(<<
"Problem while loading SVM model "
542 << std::string(model_file_name));
544 m_Parameters = m_Model->param;
545 m_ModelUpToDate =
true;
// Passes this object's m_Model to another instance (modelCopy) via SetModel.
// NOTE(review): whether SetModel deep-copies or shares the pointer is not
// visible -- confirm ownership to avoid a double free.
548 template <
class TValue,
class TLabel>
553 modelCopy->SetModel(m_Model);
// Prints object state, starting with the superclass's PrintSelf output.
558 template <
class TValue,
class TLabel>
562 Superclass::PrintSelf(os, indent);
// Replaces the model's support vectors, in these steps:
//  1. Free the old storage: SV[0] as one node buffer, null every row pointer,
//     then free the SV pointer array.
//  2. Allocate a new pointer array of m_Model->l rows.
//  3. Count the nodes of the nbOfSupportVector input vectors, walking each one
//     up to its index -1 terminator.
//  4. Copy them into a single new contiguous buffer SV[0].
//  5. Walk that buffer again (remaining lines not visible), presumably to set
//     the per-row pointers.
565 template <
class TValue,
class TLabel>
573 delete[] (m_Model->SV[0]);
575 for (
int n = 0; n < m_Model->l; ++n)
577 m_Model->SV[n] =
NULL;
579 delete[] (m_Model->SV);
582 m_Model->SV =
new struct svm_node*[m_Model->l];
588 unsigned int elements = 0;
589 for (
int p = 0; p < nbOfSupportVector; ++p)
594 while (tempNode->
index != -1)
604 SV[0] =
new struct svm_node[elements];
// SV[0] holds `elements` svm_node structs, so the byte count must use
// sizeof(svm_node). sizeof(svm_node*) copied only part of the buffer wherever
// the struct is larger than a pointer.
// NOTE(review): the copy assumes the rows of sv are stored contiguously
// starting at sv[0] -- confirm with callers.
605 memcpy(SV[0], sv[0],
sizeof(
svm_node) * elements);
610 for (
int i = 0; i < m_Model->l; ++i)
616 while (p->
index != -1)
// Replaces the support-vector coefficients. It frees the existing
// (nr_class - 1) x l sv_coef matrix, allocates a new one of the same shape,
// and copies alpha[k][i] into sv_coef[k][i].
634 template <
class TValue,
class TLabel>
641 for (
int i = 0; i < m_Model->nr_class - 1; ++i)
643 delete[] m_Model->sv_coef[i];
645 delete[] m_Model->sv_coef;
648 m_Model->sv_coef =
new double*[m_Model->nr_class - 1];
649 for (
int i = 0; i < m_Model->nr_class - 1; ++i)
650 m_Model->sv_coef[i] =
new double[m_Model->l];
652 for (
int i = 0; i < m_Model->l; ++i)
655 for (
int k = 0; k < m_Model->nr_class - 1; ++k)
657 m_Model->sv_coef[k][i] = alpha[k][i];