Orfeo Toolbox  4.2
otbSVMModel.txx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 #ifndef __otbSVMModel_txx
19 #define __otbSVMModel_txx
#include "otbSVMModel.h"
#include "otbMacro.h"

#include <algorithm>
#include <vector>
22 
23 
24 namespace otb
25 {
26 // TODO: Check memory allocation in this class
27 template <class TValue, class TLabel>
29 {
30  // Default parameters
31  this->SetSVMType(C_SVC);
32  this->SetKernelType(LINEAR);
33  this->SetPolynomialKernelDegree(3);
34  this->SetKernelGamma(1.); // 1/k
35  this->SetKernelCoef0(1.);
36  this->SetKernelFunctor(NULL);
37  this->SetNu(0.5);
38  this->SetCacheSize(40);
39  this->SetC(1);
40  this->SetEpsilon(1e-3);
41  this->SetP(0.1);
42  this->DoShrinking(true);
43  this->DoProbabilityEstimates(false);
44 
45  m_Parameters.kernel_generic = NULL;
46  m_Parameters.kernel_composed = NULL;
47  m_Parameters.nr_weight = 0;
48  m_Parameters.weight_label = NULL;
49  m_Parameters.weight = NULL;
50 
51  m_Model = NULL;
52 
53  this->Initialize();
54 }
55 
56 template <class TValue, class TLabel>
58 {
59  this->DeleteModel();
60  this->DeleteProblem();
61 }
62 template <class TValue, class TLabel>
63 void
65 {
66  // Initialize model
67  if (!m_Model)
68  {
69  m_Model = new struct svm_model;
70  m_Model->l = 0;
71  m_Model->nr_class = 0;
72  m_Model->SV = NULL;
73  m_Model->sv_coef = NULL;
74  m_Model->rho = NULL;
75  m_Model->label = NULL;
76  m_Model->probA = NULL;
77  m_Model->probB = NULL;
78  m_Model->nSV = NULL;
79 
80  m_ModelUpToDate = false;
81 
82  }
83 
84  // Intialize problem
85  m_Problem.l = 0;
86  m_Problem.y = NULL;
87  m_Problem.x = NULL;
88 
89  m_ProblemUpToDate = false;
90 }
91 
92 template <class TValue, class TLabel>
93 void
95 {
96  this->DeleteProblem();
97  this->DeleteModel();
98 
99  // Clear samples
100  m_Samples.clear();
101 
102  // Initialize values
103  this->Initialize();
104 }
105 
106 template <class TValue, class TLabel>
107 void
109 {
110  svm_free_and_destroy_model(&m_Model);
111  m_Model = NULL;
112 }
113 
114 template <class TValue, class TLabel>
115 void
117 {
118 // Deallocate any existing problem
119  if (m_Problem.y)
120  {
121  delete[] m_Problem.y;
122  m_Problem.y = NULL;
123  }
124 
125  if (m_Problem.x)
126  {
127  for (int i = 0; i < m_Problem.l; ++i)
128  {
129  if (m_Problem.x[i])
130  {
131  delete[] m_Problem.x[i];
132  }
133  }
134  delete[] m_Problem.x;
135  m_Problem.x = NULL;
136  }
137  m_Problem.l = 0;
138  m_ProblemUpToDate = false;
139 }
140 
141 template <class TValue, class TLabel>
142 void
144 {
145  SampleType newSample(measure, label);
146  m_Samples.push_back(newSample);
147  m_ProblemUpToDate = false;
148 }
149 
150 template <class TValue, class TLabel>
151 void
153 {
154  m_Samples.clear();
155  m_ProblemUpToDate = false;
156 }
157 
158 template <class TValue, class TLabel>
159 void
161 {
162  m_Samples = samples;
163  m_ProblemUpToDate = false;
164 }
165 
166 template <class TValue, class TLabel>
167 void
169 {
170  // Check if problem is up-to-date
171  if (m_ProblemUpToDate)
172  {
173  return;
174  }
175 
176  // Get number of samples
177  int probl = m_Samples.size();
178 
179  if (probl < 1)
180  {
181  itkExceptionMacro(<< "No samples, can not build SVM problem.");
182  }
183  otbMsgDebugMacro(<< "Rebuilding problem ...");
184 
185  // Get the size of the samples
186  long int elements = m_Samples[0].first.size() + 1;
187 
188  // Deallocate any previous problem
189  this->DeleteProblem();
190 
191  // Allocate the problem
192  m_Problem.l = probl;
193  m_Problem.y = new double[probl];
194  m_Problem.x = new struct svm_node*[probl];
195 
196  for (int i = 0; i < probl; ++i)
197  {
198  // Initialize labels to 0
199  m_Problem.y[i] = 0;
200  m_Problem.x[i] = new struct svm_node[elements];
201 
202  // Intialize elements (value = 0; index = -1)
203  for (unsigned int j = 0; j < static_cast<unsigned int>(elements); ++j)
204  {
205  m_Problem.x[i][j].index = -1;
206  m_Problem.x[i][j].value = 0;
207  }
208  }
209 
210  // Iterate on the samples
211  typename SamplesVectorType::const_iterator sIt = m_Samples.begin();
212  int sampleIndex = 0;
213  int maxElementIndex = 0;
214 
215  while (sIt != m_Samples.end())
216  {
217 
218  // Get the sample measurement and label
219  MeasurementType measure = sIt->first;
220  LabelType label = sIt->second;
221 
222  // Set the label
223  m_Problem.y[sampleIndex] = label;
224 
225  int elementIndex = 0;
226 
227  // Populate the svm nodes
228  for (typename MeasurementType::const_iterator eIt = measure.begin();
229  eIt != measure.end() && elementIndex < elements; ++eIt, ++elementIndex)
230  {
231  m_Problem.x[sampleIndex][elementIndex].index = elementIndex + 1;
232  m_Problem.x[sampleIndex][elementIndex].value = (*eIt);
233  }
234 
235  // Get the max index
236  if (elementIndex > maxElementIndex)
237  {
238  maxElementIndex = elementIndex;
239  }
240 
241  ++sampleIndex;
242  ++sIt;
243  }
244 
245  // Compute the kernel gamma from maxElementIndex if necessary
246  if (this->GetKernelGamma() == 0
247  && this->GetParameters().kernel_type != COMPOSED
248  && this->GetParameters().kernel_type != GENERIC) this->SetKernelGamma(1.0 / static_cast<double>(maxElementIndex));
249 
250  // problem is up-to-date
251  m_ProblemUpToDate = true;
252 }
253 
254 template <class TValue, class TLabel>
255 double
257 {
258  // Build problem
259  this->BuildProblem();
260 
261  // Check consistency
262  this->ConsistencyCheck();
263 
264  // Get the length of the problem
265  int length = m_Problem.l;
266 
267  // Temporary memory to store cross validation results
268  double *target = new double[length];
269 
270  // Do cross validation
271  svm_cross_validation(&m_Problem, &m_Parameters, nbFolders, target);
272 
273  // Evaluate accuracy
274  int i;
275  double total_correct = 0.;
276 
277  for (i = 0; i < length; ++i)
278  {
279  if (target[i] == m_Problem.y[i])
280  {
281  ++total_correct;
282  }
283  }
284  double accuracy = total_correct / length;
285 
286  // Free temporary memory
287  delete[] target;
288 
289  // return accuracy value
290  return accuracy;
291 }
292 
293 template <class TValue, class TLabel>
294 void
296 {
297  if (m_Parameters.svm_type == ONE_CLASS && this->GetDoProbabilityEstimates())
298  {
299  otbMsgDebugMacro(<< "Disabling SVM probability estimates for ONE_CLASS SVM type.");
300  this->DoProbabilityEstimates(false);
301  }
302 
303  const char* error_msg = svm_check_parameter(&m_Problem, &m_Parameters);
304 
305  if (error_msg)
306  {
307  throw itk::ExceptionObject(__FILE__, __LINE__, error_msg, ITK_LOCATION);
308  }
309 }
310 
311 template <class TValue, class TLabel>
312 void
314 {
315  // If the model is already up-to-date, return
316  if (m_ModelUpToDate)
317  {
318  return;
319  }
320 
321  // Build problem
322  this->BuildProblem();
323 
324  // Check consistency
325  this->ConsistencyCheck();
326 
327  // train the model
328  m_Model = svm_train(&m_Problem, &m_Parameters);
329 
330  // Set the model as up-to-date
331  m_ModelUpToDate = true;
332 }
333 
334 template <class TValue, class TLabel>
337 {
338  // Check if model is up-to-date
339  if (!m_ModelUpToDate)
340  {
341  itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
342  }
343 
344  // Check probability prediction
345  bool predict_probability = svm_check_probability_model(m_Model);
346 
347  if (this->GetSVMType() == ONE_CLASS)
348  {
349  predict_probability = 0;
350  }
351 
352  // Get type and number of classes
353  int svm_type = svm_get_svm_type(m_Model);
354  int nr_class = svm_get_nr_class(m_Model);
355 
356  // Allocate space for labels
357  double *prob_estimates = NULL;
358 
359  // Eventually allocate space for probabilities
360  if (predict_probability)
361  {
362  if (svm_type == NU_SVR || svm_type == EPSILON_SVR)
363  {
365  <<
366  "Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma), sigma="
367  << svm_get_svr_probability(m_Model));
368  }
369  else
370  {
371  prob_estimates = new double[nr_class];
372  }
373  }
374 
375  // Allocate nodes (/TODO if performances problems are related to too
376  // many allocations, a cache approach can be set)
377  struct svm_node * x = new struct svm_node[measure.size() + 1];
378 
379  int valueIndex = 0;
380 
381  // Fill the node
382  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
383  {
384  x[valueIndex].index = valueIndex + 1;
385  x[valueIndex].value = (*mIt);
386  }
387 
388  // terminate node
389  x[measure.size()].index = -1;
390  x[measure.size()].value = 0;
391 
392  LabelType label = 0;
393 
394  if (predict_probability && (svm_type == C_SVC || svm_type == NU_SVC))
395  {
396  label = static_cast<LabelType>(svm_predict_probability(m_Model, x, prob_estimates));
397  }
398  else
399  {
400  label = static_cast<LabelType>(svm_predict(m_Model, x));
401  }
402 
403  // Free allocated memory
404  delete[] x;
405 
406  if (prob_estimates)
407  {
408  delete[] prob_estimates;
409  }
410 
411  return label;
412 }
413 
414 template <class TValue, class TLabel>
417 {
418  // Check if model is up-to-date
419  if (!m_ModelUpToDate)
420  {
421  itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
422  }
423 
424  // Allocate nodes (/TODO if performances problems are related to too
425  // many allocations, a cache approach can be set)
426  struct svm_node * x = new struct svm_node[measure.size() + 1];
427 
428  int valueIndex = 0;
429 
430  // Fill the node
431  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
432  {
433  x[valueIndex].index = valueIndex + 1;
434  x[valueIndex].value = (*mIt);
435  }
436 
437  // terminate node
438  x[measure.size()].index = -1;
439  x[measure.size()].value = 0;
440 
441  // Intialize distances vector
442  DistancesVectorType distances(m_Model->nr_class*(m_Model->nr_class - 1) / 2);
443 
444  // predict distances vector
445  svm_predict_values(m_Model, x, (double*) (distances.GetDataPointer()));
446 
447  // Free allocated memory
448  delete[] x;
449 
450  return (distances);
451 }
452 
453 template <class TValue, class TLabel>
456 {
457  // Check if model is up-to-date
458  if (!m_ModelUpToDate)
459  {
460  itkExceptionMacro(<< "Model is not up-to-date, can not predict probabilities");
461  }
462 
463  if (svm_check_probability_model(m_Model) == 0)
464  {
465  throw itk::ExceptionObject(__FILE__, __LINE__,
466  "Model does not support probability estimates", ITK_LOCATION);
467  }
468 
469  // Get number of classes
470  int nr_class = svm_get_nr_class(m_Model);
471 
472  // Allocate nodes (/TODO if performances problems are related to too
473  // many allocations, a cache approach can be set)
474  struct svm_node * x = new struct svm_node[measure.size() + 1];
475 
476  int valueIndex = 0;
477 
478  // Fill the node
479  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
480  {
481  x[valueIndex].index = valueIndex + 1;
482  x[valueIndex].value = (*mIt);
483  }
484 
485  // Termination node
486  x[measure.size()].index = -1;
487  x[measure.size()].value = 0;
488 
489  double* dec_values = new double[nr_class];
490  svm_predict_probability(m_Model, x, dec_values);
491 
492  // Reorder values in increasing class label
493  int* labels = m_Model->label;
494  std::vector<int> orderedLabels(nr_class);
495  std::copy(labels, labels + nr_class, orderedLabels.begin());
496  std::sort(orderedLabels.begin(), orderedLabels.end());
497 
498  ProbabilitiesVectorType probabilities(nr_class);
499  for (int i = 0; i < nr_class; ++i)
500  {
501  // svm_predict_probability is such that "dec_values[i]" corresponds to label "labels[i]"
502  std::vector<int>::iterator it = std::find(orderedLabels.begin(), orderedLabels.end(), labels[i]);
503  probabilities[it - orderedLabels.begin()] = dec_values[i];
504  }
505 
506  // Free allocated memory
507  delete[] x;
508  delete[] dec_values;
509 
510  return probabilities;
511 }
512 
513 template <class TValue, class TLabel>
514 void
516 {
517  this->DeleteModel();
518  m_Model = svm_copy_model(aModel);
519  m_ModelUpToDate = true;
520 }
521 
522 template <class TValue, class TLabel>
523 void
524 SVMModel<TValue, TLabel>::SaveModel(const char* model_file_name) const
525 {
526  if (svm_save_model(model_file_name, m_Model) != 0)
527  {
528  itkExceptionMacro(<< "Problem while saving SVM model "
529  << std::string(model_file_name));
530  }
531 }
532 
533 template <class TValue, class TLabel>
534 void
535 SVMModel<TValue, TLabel>::LoadModel(const char* model_file_name)
536 {
537  this->DeleteModel();
538  m_Model = svm_load_model(model_file_name, m_Parameters.kernel_generic);
539  if (m_Model == 0)
540  {
541  itkExceptionMacro(<< "Problem while loading SVM model "
542  << std::string(model_file_name));
543  }
544  m_Parameters = m_Model->param;
545  m_ModelUpToDate = true;
546 }
547 
548 template <class TValue, class TLabel>
551 {
552  Pointer modelCopy = New();
553  modelCopy->SetModel(m_Model);
554  // We do not copy the problem to avoid sharing allocated memory
555  return modelCopy;
556 }
557 
558 template <class TValue, class TLabel>
559 void
560 SVMModel<TValue, TLabel>::PrintSelf(std::ostream& os, itk::Indent indent) const
561 {
562  Superclass::PrintSelf(os, indent);
563 }
564 
565 template <class TValue, class TLabel>
566 void
568 {
569  // TODO: rewrite this to check memory allocation
570 
571  // erase the old SV
572  // delete just the first element, it destoyes the whole pointers (cf SV filling with x_space)
573  delete[] (m_Model->SV[0]);
574 
575  for (int n = 0; n < m_Model->l; ++n)
576  {
577  m_Model->SV[n] = NULL;
578  }
579  delete[] (m_Model->SV);
580  m_Model->SV = NULL;
581 
582  m_Model->SV = new struct svm_node*[m_Model->l];
583 
584  // copy new SV values
585  svm_node **SV = m_Model->SV;
586 
587  // Compute the total number of SV elements.
588  unsigned int elements = 0;
589  for (int p = 0; p < nbOfSupportVector; ++p)
590  {
591  //std::cout << p << " ";
592  const svm_node *tempNode = sv[p];
593  //std::cout << p << " ";
594  while (tempNode->index != -1)
595  {
596  tempNode++;
597  ++elements;
598  }
599  ++elements; // for -1 values
600  }
601 
602  if (m_Model->l > 0)
603  {
604  SV[0] = new struct svm_node[elements];
605  memcpy(SV[0], sv[0], sizeof(svm_node*) * elements);
606  }
607  svm_node *x_space = SV[0];
608 
609  int j = 0;
610  for (int i = 0; i < m_Model->l; ++i)
611  {
612  // SV
613  SV[i] = &x_space[j];
614  const svm_node *p = sv[i];
615  svm_node * pCpy = SV[i];
616  while (p->index != -1)
617  {
618  pCpy->index = p->index;
619  pCpy->value = p->value;
620  ++p;
621  ++pCpy;
622  ++j;
623  }
624  pCpy->index = -1;
625  ++j;
626  }
627 
628  if (m_Model->l > 0)
629  {
630  delete[] SV[0];
631  }
632 }
633 
634 template <class TValue, class TLabel>
635 void
636 SVMModel<TValue, TLabel>::SetAlpha(double ** alpha, int nbOfSupportVector)
637 {
638  // TODO: Check memory allocation
639 
640  // Erase the old sv_coef
641  for (int i = 0; i < m_Model->nr_class - 1; ++i)
642  {
643  delete[] m_Model->sv_coef[i];
644  }
645  delete[] m_Model->sv_coef;
646 
647  // copy new sv_coef values
648  m_Model->sv_coef = new double*[m_Model->nr_class - 1];
649  for (int i = 0; i < m_Model->nr_class - 1; ++i)
650  m_Model->sv_coef[i] = new double[m_Model->l];
651 
652  for (int i = 0; i < m_Model->l; ++i)
653  {
654  // sv_coef
655  for (int k = 0; k < m_Model->nr_class - 1; ++k)
656  {
657  m_Model->sv_coef[k][i] = alpha[k][i];
658  }
659  }
660 }
661 
662 } // end namespace otb
663 
664 #endif

Generated at Sat Jul 19 2014 16:27:33 for Orfeo Toolbox with doxygen 1.8.3.1