49 int main(
int argc,
char* argv[])
63 typedef float InputPixelType;
65 typedef std::vector<InputPixelType> InputVectorType;
66 typedef int LabelPixelType;
68 const unsigned int Dimension = 2;
93 typedef MeasurePointSetType::PointType MeasurePointType;
94 typedef LabelPointSetType::PointType LabelPointType;
96 typedef MeasurePointSetType::PointsContainer MeasurePointsContainer;
97 typedef LabelPointSetType::PointsContainer LabelPointsContainer;
99 MeasurePointSetType::Pointer tPSet = MeasurePointSetType::New();
100 MeasurePointsContainer::Pointer tCont = MeasurePointsContainer::New();
114 unsigned int pointId;
119 for (pointId = 0; pointId < 100; pointId++)
124 int x_coord = lowest +
static_cast<int>(range * (rand() / (RAND_MAX + 1.0)));
125 int y_coord = lowest +
static_cast<int>(range * (rand() / (RAND_MAX + 1.0)));
127 std::cout <<
"coords : " << x_coord <<
" " << y_coord << std::endl;
140 InputVectorType measure;
141 measure.push_back(static_cast<InputPixelType>((x_coord * 1.0 -
143 measure.push_back(static_cast<InputPixelType>((y_coord * 1.0 -
154 tCont->InsertElement(pointId, tP);
155 tPSet->SetPointData(pointId, measure);
167 tPSet->SetPoints(tCont);
183 SampleType::Pointer sample = SampleType::New();
194 sample->SetPointSet(tPSet);
206 typedef otb::SVMModel<SampleType::MeasurementVectorType::ValueType,
207 LabelPixelType> ModelType;
209 ModelType::Pointer model = ModelType::New();
221 model->LoadModel(argv[1]);
235 ClassifierType::Pointer classifier = ClassifierType::New();
247 int numberOfClasses = model->GetNumberOfClasses();
248 classifier->SetNumberOfClasses(numberOfClasses);
249 classifier->SetModel(model);
250 classifier->SetSample(sample.GetPointer());
251 classifier->Update();
263 ClassifierType::OutputType* membershipSample =
264 classifier->GetOutput();
266 ClassifierType::OutputType::ConstIterator m_iter =
267 membershipSample->Begin();
268 ClassifierType::OutputType::ConstIterator m_last =
269 membershipSample->End();
282 while (m_iter != m_last)
293 ClassifierType::ClassLabelType
label = m_iter.GetClassLabel();
303 InputVectorType measure;
305 tPSet->GetPointData(pointId, &measure);
307 ClassifierType::ClassLabelType expectedLabel;
308 if (measure[0] < measure[1]) expectedLabel = -1;
309 else expectedLabel = 1;
311 double dist = fabs(measure[0] - measure[1]);
313 if (label != expectedLabel) error++;
315 std::cout << int(label) <<
"/" << int(expectedLabel) <<
" --- " << dist <<
322 std::cout <<
"Error = " << error / pointId <<
" % " << std::endl;