Orfeo Toolbox  3.16
SVMPointSetClassificationExample.cxx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 
19 
20 // Software Guide : BeginCommandLineArgs
21 // INPUTS: {svm_model.svn}
22 // OUTPUTS:
23 // Software Guide : EndCommandLineArgs
24 
#include "itkMacro.h"
#include "itkPointSet.h"
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

// Software Guide : BeginLatex
// This example illustrates the use of the
// \doxygen{otb}{SVMClassifier} class for performing SVM
// classification on pointsets.
// The first thing to do is include the header file for the
// class. Since the \doxygen{otb}{SVMClassifier} takes
// \doxygen{itk}{ListSample}s as input, the class
// \doxygen{itk}{PointSetToListAdaptor} is needed.
//
// We start by including the needed header files.
//
// Software Guide : EndLatex

// Software Guide : BeginCodeSnippet
#include "itkPointSetToListAdaptor.h"
#include "itkListSample.h"
#include "otbSVMClassifier.h"
// Software Guide : EndCodeSnippet
48 
49 int main(int argc, char* argv[])
50 {
51 // Software Guide : BeginLatex
52 //
53 // In the framework of supervised learning and classification, we will
54 // always use feature vectors for the characterization of the
55 // classes. On the other hand, the class labels are scalar
56 // values. Here, we start by defining the type of the features as the
57 // \code{PixelType}, which will be used to define the feature
58 // \code{VectorType}. We also declare the type for the labels.
59 //
60 // Software Guide : EndLatex
61 
62 // Software Guide : BeginCodeSnippet
63  typedef float InputPixelType;
64 
65  typedef std::vector<InputPixelType> InputVectorType;
66  typedef int LabelPixelType;
67 // Software Guide : EndCodeSnippet
68  const unsigned int Dimension = 2;
69 
70 // Software Guide : BeginLatex
71 //
72 // We can now proceed to define the point sets used for storing the
73 // features and the labels.
74 //
75 // Software Guide : EndLatex
76 
77 // Software Guide : BeginCodeSnippet
78  typedef itk::PointSet<InputVectorType, Dimension> MeasurePointSetType;
79 
80  typedef itk::PointSet<LabelPixelType, Dimension> LabelPointSetType;
81 // Software Guide : EndCodeSnippet
82 
83 // Software Guide : BeginLatex
84 //
85 // We will need to get access to the data stored in the point sets, so
86 // we define the appropriate for the points and the points containers
87 // used by the point sets (see the section \ref{sec:PointSetSection}
88 // for more information on how to use point sets).
89 //
90 // Software Guide : EndLatex
91 
92 // Software Guide : BeginCodeSnippet
93  typedef MeasurePointSetType::PointType MeasurePointType;
94  typedef LabelPointSetType::PointType LabelPointType;
95 
96  typedef MeasurePointSetType::PointsContainer MeasurePointsContainer;
97  typedef LabelPointSetType::PointsContainer LabelPointsContainer;
98 
99  MeasurePointSetType::Pointer tPSet = MeasurePointSetType::New();
100  MeasurePointsContainer::Pointer tCont = MeasurePointsContainer::New();
101 // Software Guide : EndCodeSnippet
102 
103 // Software Guide : BeginLatex
104 //
105 // We need now to build the test set for the SVM. In this
106 // simple example, we will build a SVM who classes points depending on
107 // which side of the line $x=y$ they are located. We start by
108 // generating 500 random points.
109 //
110 // Software Guide : EndLatex
111 
112  srand(0);
113 
114  unsigned int pointId;
115 // Software Guide : BeginCodeSnippet
116  int lowest = 0;
117  int range = 1000;
118 
119  for (pointId = 0; pointId < 100; pointId++)
120  {
121 
122  MeasurePointType tP;
123 
124  int x_coord = lowest + static_cast<int>(range * (rand() / (RAND_MAX + 1.0)));
125  int y_coord = lowest + static_cast<int>(range * (rand() / (RAND_MAX + 1.0)));
126 
127  std::cout << "coords : " << x_coord << " " << y_coord << std::endl;
128  tP[0] = x_coord;
129  tP[1] = y_coord;
130 // Software Guide : EndCodeSnippet
131 
132 // Software Guide : BeginLatex
133 //
134 // We push the features in the vector after a normalization which is
135 // useful for SVM convergence.
136 //
137 // Software Guide : EndLatex
138 
139 // Software Guide : BeginCodeSnippet
140  InputVectorType measure;
141  measure.push_back(static_cast<InputPixelType>((x_coord * 1.0 -
142  lowest) / range));
143  measure.push_back(static_cast<InputPixelType>((y_coord * 1.0 -
144  lowest) / range));
145 // Software Guide : EndCodeSnippet
146 
147 // Software Guide : BeginLatex
148 //
149 // And we insert the points in the points container.
150 //
151 // Software Guide : EndLatex
152 
153 // Software Guide : BeginCodeSnippet
154  tCont->InsertElement(pointId, tP);
155  tPSet->SetPointData(pointId, measure);
156 
157  }
158 // Software Guide : EndCodeSnippet
159 
160 // Software Guide : BeginLatex
161 //
162 // After the loop, we set the points container to the point set.
163 //
164 // Software Guide : EndLatex
165 
166 // Software Guide : BeginCodeSnippet
167  tPSet->SetPoints(tCont);
168 // Software Guide : EndCodeSnippet
169 
170 // Software Guide : BeginLatex
171 //
172 // Once the pointset is ready, we must transform it to a sample which
173 // is compatible with the classification framework. We will use a
174 // \doxygen{itk}{Statistics::PointSetToListAdaptor} for this
175 // task. This class is templated over the point set type used for
176 // storing the measures.
177 //
178 // Software Guide : EndLatex
179 
180 // Software Guide : BeginCodeSnippet
182  SampleType;
183  SampleType::Pointer sample = SampleType::New();
184 // Software Guide : EndCodeSnippet
185 
186 // Software Guide : BeginLatex
187 //
188 // After instantiation, we can set the point set as an imput of our
189 // sample adaptor.
190 //
191 // Software Guide : EndLatex
192 
193 // Software Guide : BeginCodeSnippet
194  sample->SetPointSet(tPSet);
195 // Software Guide : EndCodeSnippet
196 
197 // Software Guide : BeginLatex
198 //
199 // Now, we need to declare the SVM model which is to be used by the
200 // classifier. The SVM model is templated over the type of value used
201 // for the measures and the type of pixel used for the labels.
202 //
203 // Software Guide : EndLatex
204 
205 // Software Guide : BeginCodeSnippet
206  typedef otb::SVMModel<SampleType::MeasurementVectorType::ValueType,
207  LabelPixelType> ModelType;
208 
209  ModelType::Pointer model = ModelType::New();
210 // Software Guide : EndCodeSnippet
211 
212 // Software Guide : BeginLatex
213 //
214 // After instantiation, we can load a model saved to a file (see
215 // section \ref{sec:LearningWithPointSets} for an example of model
216 // estimation and storage to a file).
217 //
218 // Software Guide : EndLatex
219 
220 // Software Guide : BeginCodeSnippet
221  model->LoadModel(argv[1]);
222 // Software Guide : EndCodeSnippet
223 
224 // Software Guide : BeginLatex
225 //
226 // We have now all the elements to create a classifier. The classifier
227 // is templated over the sample type (the type of the data to be
228 // classified) and the label type (the type of the output of the classifier).
229 //
230 // Software Guide : EndLatex
231 
232 // Software Guide : BeginCodeSnippet
233  typedef otb::SVMClassifier<SampleType, LabelPixelType> ClassifierType;
234 
235  ClassifierType::Pointer classifier = ClassifierType::New();
236 // Software Guide : EndCodeSnippet
237 
238 // Software Guide : BeginLatex
239 //
240 // We set the classifier parameters : number of classes, SVM model,
241 // the sample data. And we trigger the classification process by
242 // calling the \code{Update} method.
243 //
244 // Software Guide : EndLatex
245 
246 // Software Guide : BeginCodeSnippet
247  int numberOfClasses = model->GetNumberOfClasses();
248  classifier->SetNumberOfClasses(numberOfClasses);
249  classifier->SetModel(model);
250  classifier->SetSample(sample.GetPointer());
251  classifier->Update();
252 // Software Guide : EndCodeSnippet
253 
254 // Software Guide : BeginLatex
255 //
256 // After the classification step, we usually want to get the
257 // results. The classifier gives an output under the form of a sample
258 // list. This list supports the classical STL iterators.
259 //
260 // Software Guide : EndLatex
261 
262 // Software Guide : BeginCodeSnippet
263  ClassifierType::OutputType* membershipSample =
264  classifier->GetOutput();
265 
266  ClassifierType::OutputType::ConstIterator m_iter =
267  membershipSample->Begin();
268  ClassifierType::OutputType::ConstIterator m_last =
269  membershipSample->End();
270 // Software Guide : EndCodeSnippet
271 
272 // Software Guide : BeginLatex
273 //
274 // We will iterate through the list, get the labels and compute the
275 // classification error.
276 //
277 // Software Guide : EndLatex
278 
279 // Software Guide : BeginCodeSnippet
280  double error = 0.0;
281  pointId = 0;
282  while (m_iter != m_last)
283  {
284 // Software Guide : EndCodeSnippet
285 
286 // Software Guide : BeginLatex
287 //
288 // We get the label for each point.
289 //
290 // Software Guide : EndLatex
291 
292 // Software Guide : BeginCodeSnippet
293  ClassifierType::ClassLabelType label = m_iter.GetClassLabel();
294 // Software Guide : EndCodeSnippet
295 
296 // Software Guide : BeginLatex
297 //
298 // And we compare it to the corresponding one of the test set.
299 //
300 // Software Guide : EndLatex
301 
302 // Software Guide : BeginCodeSnippet
303  InputVectorType measure;
304 
305  tPSet->GetPointData(pointId, &measure);
306 
307  ClassifierType::ClassLabelType expectedLabel;
308  if (measure[0] < measure[1]) expectedLabel = -1;
309  else expectedLabel = 1;
310 
311  double dist = fabs(measure[0] - measure[1]);
312 
313  if (label != expectedLabel) error++;
314 
315  std::cout << int(label) << "/" << int(expectedLabel) << " --- " << dist <<
316  std::endl;
317 
318  ++pointId;
319  ++m_iter;
320  }
321 
322  std::cout << "Error = " << error / pointId << " % " << std::endl;
323 // Software Guide : EndCodeSnippet
324 
325  return EXIT_SUCCESS;
326 }

Generated at Sun May 19 2013 01:02:00 for Orfeo Toolbox with doxygen 1.8.3.1