Orfeo ToolBox  4.2
Orfeo ToolBox is not a black box
otbSVMModel.txx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 #ifndef __otbSVMModel_txx
19 #define __otbSVMModel_txx
20 #include "otbSVMModel.h"
21 #include "otbMacro.h"
22 
23 #include <algorithm>
24 
25 namespace otb
26 {
27 // TODO: Check memory allocation in this class
28 template <class TValue, class TLabel>
30 {
31  // Default parameters
32  this->SetSVMType(C_SVC);
33  this->SetKernelType(LINEAR);
34  this->SetPolynomialKernelDegree(3);
35  this->SetKernelGamma(1.); // 1/k
36  this->SetKernelCoef0(1.);
37  this->SetKernelFunctor(NULL);
38  this->SetNu(0.5);
39  this->SetCacheSize(40);
40  this->SetC(1);
41  this->SetEpsilon(1e-3);
42  this->SetP(0.1);
43  this->DoShrinking(true);
44  this->DoProbabilityEstimates(false);
45 
46  m_Parameters.kernel_generic = NULL;
47  m_Parameters.kernel_composed = NULL;
48  m_Parameters.nr_weight = 0;
49  m_Parameters.weight_label = NULL;
50  m_Parameters.weight = NULL;
51 
52  m_Model = NULL;
53 
54  this->Initialize();
55 }
56 
57 template <class TValue, class TLabel>
59 {
60  this->DeleteModel();
61  this->DeleteProblem();
62 }
63 template <class TValue, class TLabel>
64 void
66 {
67  // Initialize model
68  if (!m_Model)
69  {
70  m_Model = new struct svm_model;
71  m_Model->l = 0;
72  m_Model->nr_class = 0;
73  m_Model->SV = NULL;
74  m_Model->sv_coef = NULL;
75  m_Model->rho = NULL;
76  m_Model->label = NULL;
77  m_Model->probA = NULL;
78  m_Model->probB = NULL;
79  m_Model->nSV = NULL;
80 
81  m_ModelUpToDate = false;
82 
83  }
84 
85  // Intialize problem
86  m_Problem.l = 0;
87  m_Problem.y = NULL;
88  m_Problem.x = NULL;
89 
90  m_ProblemUpToDate = false;
91 }
92 
93 template <class TValue, class TLabel>
94 void
96 {
97  this->DeleteProblem();
98  this->DeleteModel();
99 
100  // Clear samples
101  m_Samples.clear();
102 
103  // Initialize values
104  this->Initialize();
105 }
106 
107 template <class TValue, class TLabel>
108 void
110 {
111  svm_free_and_destroy_model(&m_Model);
112  m_Model = NULL;
113 }
114 
115 template <class TValue, class TLabel>
116 void
118 {
119 // Deallocate any existing problem
120  if (m_Problem.y)
121  {
122  delete[] m_Problem.y;
123  m_Problem.y = NULL;
124  }
125 
126  if (m_Problem.x)
127  {
128  for (int i = 0; i < m_Problem.l; ++i)
129  {
130  if (m_Problem.x[i])
131  {
132  delete[] m_Problem.x[i];
133  }
134  }
135  delete[] m_Problem.x;
136  m_Problem.x = NULL;
137  }
138  m_Problem.l = 0;
139  m_ProblemUpToDate = false;
140 }
141 
142 template <class TValue, class TLabel>
143 void
145 {
146  SampleType newSample(measure, label);
147  m_Samples.push_back(newSample);
148  m_ProblemUpToDate = false;
149 }
150 
151 template <class TValue, class TLabel>
152 void
154 {
155  m_Samples.clear();
156  m_ProblemUpToDate = false;
157 }
158 
159 template <class TValue, class TLabel>
160 void
162 {
163  m_Samples = samples;
164  m_ProblemUpToDate = false;
165 }
166 
167 template <class TValue, class TLabel>
168 void
170 {
171  // Check if problem is up-to-date
172  if (m_ProblemUpToDate)
173  {
174  return;
175  }
176 
177  // Get number of samples
178  int probl = m_Samples.size();
179 
180  if (probl < 1)
181  {
182  itkExceptionMacro(<< "No samples, can not build SVM problem.");
183  }
184  otbMsgDebugMacro(<< "Rebuilding problem ...");
185 
186  // Get the size of the samples
187  long int elements = m_Samples[0].first.size() + 1;
188 
189  // Deallocate any previous problem
190  this->DeleteProblem();
191 
192  // Allocate the problem
193  m_Problem.l = probl;
194  m_Problem.y = new double[probl];
195  m_Problem.x = new struct svm_node*[probl];
196 
197  for (int i = 0; i < probl; ++i)
198  {
199  // Initialize labels to 0
200  m_Problem.y[i] = 0;
201  m_Problem.x[i] = new struct svm_node[elements];
202 
203  // Intialize elements (value = 0; index = -1)
204  for (unsigned int j = 0; j < static_cast<unsigned int>(elements); ++j)
205  {
206  m_Problem.x[i][j].index = -1;
207  m_Problem.x[i][j].value = 0;
208  }
209  }
210 
211  // Iterate on the samples
212  typename SamplesVectorType::const_iterator sIt = m_Samples.begin();
213  int sampleIndex = 0;
214  int maxElementIndex = 0;
215 
216  while (sIt != m_Samples.end())
217  {
218 
219  // Get the sample measurement and label
220  MeasurementType measure = sIt->first;
221  LabelType label = sIt->second;
222 
223  // Set the label
224  m_Problem.y[sampleIndex] = label;
225 
226  int elementIndex = 0;
227 
228  // Populate the svm nodes
229  for (typename MeasurementType::const_iterator eIt = measure.begin();
230  eIt != measure.end() && elementIndex < elements; ++eIt, ++elementIndex)
231  {
232  m_Problem.x[sampleIndex][elementIndex].index = elementIndex + 1;
233  m_Problem.x[sampleIndex][elementIndex].value = (*eIt);
234  }
235 
236  // Get the max index
237  if (elementIndex > maxElementIndex)
238  {
239  maxElementIndex = elementIndex;
240  }
241 
242  ++sampleIndex;
243  ++sIt;
244  }
245 
246  // Compute the kernel gamma from maxElementIndex if necessary
247  if (this->GetKernelGamma() == 0
248  && this->GetParameters().kernel_type != COMPOSED
249  && this->GetParameters().kernel_type != GENERIC) this->SetKernelGamma(1.0 / static_cast<double>(maxElementIndex));
250 
251  // problem is up-to-date
252  m_ProblemUpToDate = true;
253 }
254 
255 template <class TValue, class TLabel>
256 double
258 {
259  // Build problem
260  this->BuildProblem();
261 
262  // Check consistency
263  this->ConsistencyCheck();
264 
265  // Get the length of the problem
266  int length = m_Problem.l;
267 
268  // Temporary memory to store cross validation results
269  double *target = new double[length];
270 
271  // Do cross validation
272  svm_cross_validation(&m_Problem, &m_Parameters, nbFolders, target);
273 
274  // Evaluate accuracy
275  int i;
276  double total_correct = 0.;
277 
278  for (i = 0; i < length; ++i)
279  {
280  if (target[i] == m_Problem.y[i])
281  {
282  ++total_correct;
283  }
284  }
285  double accuracy = total_correct / length;
286 
287  // Free temporary memory
288  delete[] target;
289 
290  // return accuracy value
291  return accuracy;
292 }
293 
294 template <class TValue, class TLabel>
295 void
297 {
298  if (m_Parameters.svm_type == ONE_CLASS && this->GetDoProbabilityEstimates())
299  {
300  otbMsgDebugMacro(<< "Disabling SVM probability estimates for ONE_CLASS SVM type.");
301  this->DoProbabilityEstimates(false);
302  }
303 
304  const char* error_msg = svm_check_parameter(&m_Problem, &m_Parameters);
305 
306  if (error_msg)
307  {
308  throw itk::ExceptionObject(__FILE__, __LINE__, error_msg, ITK_LOCATION);
309  }
310 }
311 
312 template <class TValue, class TLabel>
313 void
315 {
316  // If the model is already up-to-date, return
317  if (m_ModelUpToDate)
318  {
319  return;
320  }
321 
322  // Build problem
323  this->BuildProblem();
324 
325  // Check consistency
326  this->ConsistencyCheck();
327 
328  // train the model
329  m_Model = svm_train(&m_Problem, &m_Parameters);
330 
331  // Set the model as up-to-date
332  m_ModelUpToDate = true;
333 }
334 
335 template <class TValue, class TLabel>
338 {
339  // Check if model is up-to-date
340  if (!m_ModelUpToDate)
341  {
342  itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
343  }
344 
345  // Check probability prediction
346  bool predict_probability = svm_check_probability_model(m_Model);
347 
348  if (this->GetSVMType() == ONE_CLASS)
349  {
350  predict_probability = 0;
351  }
352 
353  // Get type and number of classes
354  int svm_type = svm_get_svm_type(m_Model);
355  int nr_class = svm_get_nr_class(m_Model);
356 
357  // Allocate space for labels
358  double *prob_estimates = NULL;
359 
360  // Eventually allocate space for probabilities
361  if (predict_probability)
362  {
363  if (svm_type == NU_SVR || svm_type == EPSILON_SVR)
364  {
366  <<
367  "Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma), sigma="
368  << svm_get_svr_probability(m_Model));
369  }
370  else
371  {
372  prob_estimates = new double[nr_class];
373  }
374  }
375 
376  // Allocate nodes (/TODO if performances problems are related to too
377  // many allocations, a cache approach can be set)
378  struct svm_node * x = new struct svm_node[measure.size() + 1];
379 
380  int valueIndex = 0;
381 
382  // Fill the node
383  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
384  {
385  x[valueIndex].index = valueIndex + 1;
386  x[valueIndex].value = (*mIt);
387  }
388 
389  // terminate node
390  x[measure.size()].index = -1;
391  x[measure.size()].value = 0;
392 
393  LabelType label = 0;
394 
395  if (predict_probability && (svm_type == C_SVC || svm_type == NU_SVC))
396  {
397  label = static_cast<LabelType>(svm_predict_probability(m_Model, x, prob_estimates));
398  }
399  else
400  {
401  label = static_cast<LabelType>(svm_predict(m_Model, x));
402  }
403 
404  // Free allocated memory
405  delete[] x;
406 
407  if (prob_estimates)
408  {
409  delete[] prob_estimates;
410  }
411 
412  return label;
413 }
414 
415 template <class TValue, class TLabel>
418 {
419  // Check if model is up-to-date
420  if (!m_ModelUpToDate)
421  {
422  itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
423  }
424 
425  // Allocate nodes (/TODO if performances problems are related to too
426  // many allocations, a cache approach can be set)
427  struct svm_node * x = new struct svm_node[measure.size() + 1];
428 
429  int valueIndex = 0;
430 
431  // Fill the node
432  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
433  {
434  x[valueIndex].index = valueIndex + 1;
435  x[valueIndex].value = (*mIt);
436  }
437 
438  // terminate node
439  x[measure.size()].index = -1;
440  x[measure.size()].value = 0;
441 
442  // Intialize distances vector
443  DistancesVectorType distances(m_Model->nr_class*(m_Model->nr_class - 1) / 2);
444 
445  // predict distances vector
446  svm_predict_values(m_Model, x, (double*) (distances.GetDataPointer()));
447 
448  // Free allocated memory
449  delete[] x;
450 
451  return (distances);
452 }
453 
454 template <class TValue, class TLabel>
457 {
458  // Check if model is up-to-date
459  if (!m_ModelUpToDate)
460  {
461  itkExceptionMacro(<< "Model is not up-to-date, can not predict probabilities");
462  }
463 
464  if (svm_check_probability_model(m_Model) == 0)
465  {
466  throw itk::ExceptionObject(__FILE__, __LINE__,
467  "Model does not support probability estimates", ITK_LOCATION);
468  }
469 
470  // Get number of classes
471  int nr_class = svm_get_nr_class(m_Model);
472 
473  // Allocate nodes (/TODO if performances problems are related to too
474  // many allocations, a cache approach can be set)
475  struct svm_node * x = new struct svm_node[measure.size() + 1];
476 
477  int valueIndex = 0;
478 
479  // Fill the node
480  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
481  {
482  x[valueIndex].index = valueIndex + 1;
483  x[valueIndex].value = (*mIt);
484  }
485 
486  // Termination node
487  x[measure.size()].index = -1;
488  x[measure.size()].value = 0;
489 
490  double* dec_values = new double[nr_class];
491  svm_predict_probability(m_Model, x, dec_values);
492 
493  // Reorder values in increasing class label
494  int* labels = m_Model->label;
495  std::vector<int> orderedLabels(nr_class);
496  std::copy(labels, labels + nr_class, orderedLabels.begin());
497  std::sort(orderedLabels.begin(), orderedLabels.end());
498 
499  ProbabilitiesVectorType probabilities(nr_class);
500  for (int i = 0; i < nr_class; ++i)
501  {
502  // svm_predict_probability is such that "dec_values[i]" corresponds to label "labels[i]"
503  std::vector<int>::iterator it = std::find(orderedLabels.begin(), orderedLabels.end(), labels[i]);
504  probabilities[it - orderedLabels.begin()] = dec_values[i];
505  }
506 
507  // Free allocated memory
508  delete[] x;
509  delete[] dec_values;
510 
511  return probabilities;
512 }
513 
514 template <class TValue, class TLabel>
515 void
516 SVMModel<TValue, TLabel>::SetModel(struct svm_model* aModel)
517 {
518  this->DeleteModel();
519  m_Model = svm_copy_model(aModel);
520  m_ModelUpToDate = true;
521 }
522 
523 template <class TValue, class TLabel>
524 void
525 SVMModel<TValue, TLabel>::SaveModel(const char* model_file_name) const
526 {
527  if (svm_save_model(model_file_name, m_Model) != 0)
528  {
529  itkExceptionMacro(<< "Problem while saving SVM model "
530  << std::string(model_file_name));
531  }
532 }
533 
534 template <class TValue, class TLabel>
535 void
536 SVMModel<TValue, TLabel>::LoadModel(const char* model_file_name)
537 {
538  this->DeleteModel();
539  m_Model = svm_load_model(model_file_name, m_Parameters.kernel_generic);
540  if (m_Model == 0)
541  {
542  itkExceptionMacro(<< "Problem while loading SVM model "
543  << std::string(model_file_name));
544  }
545  m_Parameters = m_Model->param;
546  m_ModelUpToDate = true;
547 }
548 
549 template <class TValue, class TLabel>
552 {
553  Pointer modelCopy = New();
554  modelCopy->SetModel(m_Model);
555  // We do not copy the problem to avoid sharing allocated memory
556  return modelCopy;
557 }
558 
559 template <class TValue, class TLabel>
560 void
561 SVMModel<TValue, TLabel>::PrintSelf(std::ostream& os, itk::Indent indent) const
562 {
563  Superclass::PrintSelf(os, indent);
564 }
565 
566 template <class TValue, class TLabel>
567 void
568 SVMModel<TValue, TLabel>::SetSupportVectors(svm_node ** sv, int nbOfSupportVector)
569 {
570  // TODO: rewrite this to check memory allocation
571 
572  // erase the old SV
573  // delete just the first element, it destoyes the whole pointers (cf SV filling with x_space)
574  delete[] (m_Model->SV[0]);
575 
576  for (int n = 0; n < m_Model->l; ++n)
577  {
578  m_Model->SV[n] = NULL;
579  }
580  delete[] (m_Model->SV);
581  m_Model->SV = NULL;
582 
583  m_Model->SV = new struct svm_node*[m_Model->l];
584 
585  // copy new SV values
586  svm_node **SV = m_Model->SV;
587 
588  // Compute the total number of SV elements.
589  unsigned int elements = 0;
590  for (int p = 0; p < nbOfSupportVector; ++p)
591  {
592  //std::cout << p << " ";
593  const svm_node *tempNode = sv[p];
594  //std::cout << p << " ";
595  while (tempNode->index != -1)
596  {
597  tempNode++;
598  ++elements;
599  }
600  ++elements; // for -1 values
601  }
602 
603  if (m_Model->l > 0)
604  {
605  SV[0] = new struct svm_node[elements];
606  memcpy(SV[0], sv[0], sizeof(svm_node*) * elements);
607  }
608  svm_node *x_space = SV[0];
609 
610  int j = 0;
611  for (int i = 0; i < m_Model->l; ++i)
612  {
613  // SV
614  SV[i] = &x_space[j];
615  const svm_node *p = sv[i];
616  svm_node * pCpy = SV[i];
617  while (p->index != -1)
618  {
619  pCpy->index = p->index;
620  pCpy->value = p->value;
621  ++p;
622  ++pCpy;
623  ++j;
624  }
625  pCpy->index = -1;
626  ++j;
627  }
628 
629  if (m_Model->l > 0)
630  {
631  delete[] SV[0];
632  }
633 }
634 
635 template <class TValue, class TLabel>
636 void
637 SVMModel<TValue, TLabel>::SetAlpha(double ** alpha, int nbOfSupportVector)
638 {
639  // TODO: Check memory allocation
640 
641  // Erase the old sv_coef
642  for (int i = 0; i < m_Model->nr_class - 1; ++i)
643  {
644  delete[] m_Model->sv_coef[i];
645  }
646  delete[] m_Model->sv_coef;
647 
648  // copy new sv_coef values
649  m_Model->sv_coef = new double*[m_Model->nr_class - 1];
650  for (int i = 0; i < m_Model->nr_class - 1; ++i)
651  m_Model->sv_coef[i] = new double[m_Model->l];
652 
653  for (int i = 0; i < m_Model->l; ++i)
654  {
655  // sv_coef
656  for (int k = 0; k < m_Model->nr_class - 1; ++k)
657  {
658  m_Model->sv_coef[k][i] = alpha[k][i];
659  }
660  }
661 }
662 
663 } // end namespace otb
664 
665 #endif
void SaveModel(const char *model_file_name) const
DistancesVectorType EvaluateHyperplanesDistances(const MeasurementType &measure) const
void SetSamples(const SamplesVectorType &samples)
void ClearSamples()
void AddSample(const MeasurementType &measure, const LabelType &label)
void BuildProblem()
void Initialize()
Definition: otbSVMModel.txx:65
void LoadModel(const char *model_file_name)
void ConsistencyCheck()
void SetModel(struct svm_model *aModel)
const TValue * GetDataPointer() const
#define otbMsgDebugMacro(x)
Definition: otbMacro.h:54
void SetAlpha(double **alpha, int nbOfSupportVector)
LabelType EvaluateLabel(const MeasurementType &measure) const
TLabel LabelType
Definition: otbSVMModel.h:71
double CrossValidation(unsigned int nbFolders)
void DeleteProblem()
ProbabilitiesVectorType EvaluateProbabilities(const MeasurementType &measure) const
void PrintSelf(std::ostream &os, itk::Indent indent) const
#define NULL
Pointer GetCopy() const
std::vector< SampleType > SamplesVectorType
Definition: otbSVMModel.h:74
std::vector< ValueType > MeasurementType
Definition: otbSVMModel.h:72
std::pair< MeasurementType, LabelType > SampleType
Definition: otbSVMModel.h:73
void SetSupportVectors(svm_node **sv, int nbOfSupportVector)
#define otbMsgDevMacro(x)
Definition: otbMacro.h:94
void DeleteModel()
virtual ~SVMModel()
Definition: otbSVMModel.txx:58