21 #ifndef otbListSampleGenerator_hxx
22 #define otbListSampleGenerator_hxx
26 #include "itkImageRegionConstIteratorWithIndex.h"
// Constructor: sets default sampling parameters, declares the two required
// inputs (image, vector data) and allocates the four list outputs.
// The max sizes default to -1; the probability computation only applies a
// cap when the value differs from -1, so -1 acts as "no limit".
63 template <
class TImage,
class TVectorData>
65 : m_MaxTrainingSize(-1),
66 m_MaxValidationSize(-1),
67 m_ValidationTrainingProportion(0.0),
69 m_PolygonEdgeInclusion(false),
// Input 0 is the image, input 1 is the vector data.
74 this->SetNumberOfRequiredInputs(2);
75 this->SetNumberOfRequiredOutputs(4);
// Outputs 0..3: training samples, training labels, validation samples,
// validation labels (see the output accessors' indices).
78 this->itk::ProcessObject::SetNthOutput(0, this->
MakeOutput(0).GetPointer());
79 this->itk::ProcessObject::SetNthOutput(1, this->
MakeOutput(1).GetPointer());
80 this->itk::ProcessObject::SetNthOutput(2, this->
MakeOutput(2).GetPointer());
81 this->itk::ProcessObject::SetNthOutput(3, this->
MakeOutput(3).GetPointer());
// Sets the image to sample pixels from; it is stored as input 0.
// The const_cast is presumably required because ProcessObject::SetNthInput
// takes a non-const DataObject pointer — confirm against the ITK API.
86 template <
class TImage,
class TVectorData>
89 this->ProcessObject::SetNthInput(0,
const_cast<ImageType*
>(image));
// Returns input 0 (the image) as a const ImageType pointer.
// The input count is checked first so that the filter does not read an
// input that was never set.
// NOTE(review): the value returned in that case is presumably a null pointer — confirm.
92 template <
class TImage,
class TVectorData>
95 if (this->GetNumberOfInputs() < 1)
100 return static_cast<const ImageType*
>(this->ProcessObject::GetInput(0));
// Sets the vector data whose polygon features define the labelled
// sampling areas; it is stored as input 1.
103 template <
class TImage,
class TVectorData>
106 this->ProcessObject::SetNthInput(1,
const_cast<VectorDataType*
>(vectorData));
// Returns input 1 (the vector data) as a const VectorDataType pointer.
// It guards on the input count before reading input 1.
// NOTE(review): the value returned when fewer than two inputs are set is
// presumably a null pointer — confirm.
111 template <
class TImage,
class TVectorData>
114 if (this->GetNumberOfInputs() < 2)
119 return static_cast<const VectorDataType*
>(this->ProcessObject::GetInput(1));
// Factory for the filter's output data objects.
// The allocations alternate between ListSampleType (measurement vectors)
// and ListLabelType (class labels). This presumably maps to output indices
// 0..3 in order, matching the accessors: training samples, training labels,
// validation samples, validation labels.
// The final ListSampleType allocation presumably serves as the fallback for
// any other index — confirm against the selecting branch.
122 template <
class TImage,
class TVectorData>
129 output =
static_cast<itk::DataObject*
>(ListSampleType::New().GetPointer());
132 output =
static_cast<itk::DataObject*
>(ListLabelType::New().GetPointer());
135 output =
static_cast<itk::DataObject*
>(ListSampleType::New().GetPointer());
138 output =
static_cast<itk::DataObject*
>(ListLabelType::New().GetPointer());
141 output =
static_cast<itk::DataObject*
>(ListSampleType::New().GetPointer());
// Returns output 0, the list of training measurement vectors
// (one entry per selected pixel).
147 template <
class TImage,
class TVectorData>
150 return dynamic_cast<ListSampleType*
>(this->itk::ProcessObject::GetOutput(0));
// Returns output 1, the class labels parallel to the training sample list.
153 template <
class TImage,
class TVectorData>
156 return dynamic_cast<ListLabelType*
>(this->itk::ProcessObject::GetOutput(1));
// Returns output 2, the list of validation measurement vectors.
160 template <
class TImage,
class TVectorData>
163 return dynamic_cast<ListSampleType*
>(this->itk::ProcessObject::GetOutput(2));
// Returns output 3, the class labels parallel to the validation sample list.
168 template <
class TImage,
class TVectorData>
171 return dynamic_cast<ListLabelType*
>(this->itk::ProcessObject::GetOutput(3));
// Requests only a dummy region on the input image during the pipeline's
// region negotiation. The pixels actually needed are requested later, one
// polygon bounding region at a time, inside the data-generation step.
// NOTE(review): dummySize is assigned outside the visible lines; presumably
// it is zero-filled so that nothing is read up front — confirm.
174 template <
class TImage,
class TVectorData>
185 typename ImageType::RegionType dummyRegion;
186 typename ImageType::SizeType dummySize;
188 dummyRegion.SetSize(dummySize);
189 img->SetRequestedRegion(dummyRegion);
// Fills the four output lists by randomly assigning pixels covered by the
// labelled polygons to a training set, a validation set, or neither.
194 template <
class TImage,
class TVectorData>
// Per-class polygon areas first, then the per-class selection
// probabilities derived from them.
209 this->GenerateClassStatistics();
211 this->ComputeClassSelectionProbability();
// Reset outputs so a re-run does not append to previous results.
214 trainingListSample->Clear();
215 trainingListLabel->Clear();
216 validationListSample->Clear();
217 validationListLabel->Clear();
// Samples carry one value per image component; labels are scalar.
220 trainingListSample->SetMeasurementVectorSize(image->GetNumberOfComponentsPerPixel());
222 trainingListLabel->SetMeasurementVectorSize(1);
223 validationListSample->SetMeasurementVectorSize(image->GetNumberOfComponentsPerPixel());
225 validationListLabel->SetMeasurementVectorSize(1);
227 m_ClassesSamplesNumberTraining.clear();
228 m_ClassesSamplesNumberValidation.clear();
230 typename ImageType::RegionType imageLargestRegion = image->GetLargestPossibleRegion();
// Only polygon features contribute samples.
233 for (itVector.GoToBegin(); !itVector.IsAtEnd(); ++itVector)
235 if (itVector.Get()->IsPolygonFeature())
// Polygons whose bounding region lies entirely outside the image are skipped.
241 const bool hasIntersection = polygonRegion.Crop(imageLargestRegion);
242 if (!hasIntersection)
// Stream in only the (cropped) bounding region of this polygon.
247 image->SetRequestedRegion(polygonRegion);
248 image->PropagateRequestedRegion();
249 image->UpdateOutputData();
251 typedef itk::ImageRegionConstIteratorWithIndex<ImageType> IteratorType;
252 IteratorType it(image, polygonRegion);
254 for (it.GoToBegin(); !it.IsAtEnd(); ++it)
// NOTE(review): an itk::ContinuousIndex is used to hold a *physical* point
// here; it converts as a point type, but ImageType::PointType would state
// the intent more clearly.
256 itk::ContinuousIndex<double, 2> point;
257 image->TransformIndexToPhysicalPoint(it.GetIndex(), point);
// Pixels on the outer boundary count only when edge inclusion is enabled.
259 if (exteriorRing->IsInside(point) || (this->GetPolygonEdgeInclusion() && exteriorRing->IsOnEdge(point)))
// Reject pixels that fall inside (or, with edge inclusion, on) any hole.
263 bool isInsideInteriorRing =
false;
264 for (
typename PolygonListType::Iterator interiorRing = interiorRings->Begin(); interiorRing != interiorRings->End(); ++interiorRing)
266 if (interiorRing.Get()->IsInside(point) || (this->GetPolygonEdgeInclusion() && interiorRing.Get()->IsOnEdge(point)))
268 isInsideInteriorRing =
true;
272 if (isInsideInteriorRing)
// A single uniform draw in [0, 1) decides the pixel's fate:
//   r < pTrain                   -> training set
//   pTrain <= r < pTrain + pVal  -> validation set
//   otherwise                    -> discarded
// NOTE(review): std::map::operator[] inserts a zero entry for a label that
// is missing from the probability maps instead of failing.
277 double randomValue = m_RandomGenerator->GetUniformVariate(0.0, 1.0);
278 if (randomValue < m_ClassesProbTraining[itVector.Get()->GetFieldAsInt(m_ClassKey)])
281 trainingListSample->PushBack(it.Get());
282 trainingListLabel->PushBack(itVector.Get()->GetFieldAsInt(m_ClassKey));
283 m_ClassesSamplesNumberTraining[itVector.Get()->GetFieldAsInt(m_ClassKey)] += 1;
285 else if (randomValue <
286 m_ClassesProbTraining[itVector.Get()->GetFieldAsInt(m_ClassKey)] + m_ClassesProbValidation[itVector.Get()->GetFieldAsInt(m_ClassKey)])
289 validationListSample->PushBack(it.Get());
290 validationListLabel->PushBack(itVector.Get()->GetFieldAsInt(m_ClassKey));
291 m_ClassesSamplesNumberValidation[itVector.Get()->GetFieldAsInt(m_ClassKey)] += 1;
// Sample and label lists are filled in lockstep and must stay parallel.
299 assert(trainingListSample->Size() == trainingListLabel->Size());
300 assert(validationListSample->Size() == validationListLabel->Size());
301 this->UpdateProgress(1.0f);
// Accumulates, per class label (read from the m_ClassKey field), the total
// area of that class's polygon features expressed in pixel units.
// m_NumberOfClasses becomes the number of distinct labels that own at
// least one polygon.
304 template <
class TImage,
class TVectorData>
307 m_ClassesSize.clear();
310 typename VectorDataType::ConstPointer vectorData = this->GetInputVectorData();
314 for (itVector.GoToBegin(); !itVector.IsAtEnd(); ++itVector)
// Non-polygon features (points, lines) are ignored.
317 if (datanode->IsPolygonFeature())
319 double area = GetPolygonAreaInPixelsUnits(datanode, image);
320 m_ClassesSize[datanode->GetFieldAsInt(m_ClassKey)] += area;
323 m_NumberOfClasses = m_ClassesSize.size();
// Derives, from the per-class areas in m_ClassesSize, the probability that
// a pixel of each class is drawn into the training set and into the
// validation set.
// Throws itk::ExceptionObject when no class has any polygon area.
326 template <
class TImage,
class TVectorData>
329 m_ClassesProbTraining.clear();
330 m_ClassesProbValidation.clear();
333 if (m_ClassesSize.empty())
335 itkGenericExceptionMacro(<<
"No training sample found inside image");
// Area (in pixels) of the smallest class.
339 double minSize = itk::NumericTraits<double>::max();
340 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
342 if (minSize > itmap->second)
344 minSize = itmap->second;
// Split the smallest class between training and validation, then clamp
// each part by its user cap (-1 means no cap).
349 double minSizeTraining = minSize * (1.0 - m_ValidationTrainingProportion);
350 double minSizeValidation = minSize * m_ValidationTrainingProportion;
355 if ((m_MaxTrainingSize != -1) && (m_MaxTrainingSize < minSizeTraining))
357 minSizeTraining = m_MaxTrainingSize;
359 if ((m_MaxValidationSize != -1) && (m_MaxValidationSize < minSizeValidation))
361 minSizeValidation = m_MaxValidationSize;
365 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
// Strategy A: every class gets minSize/classSize so that each class yields
// roughly the same expected number of samples (balanced on the smallest class).
369 m_ClassesProbTraining[itmap->first] = minSizeTraining / itmap->second;
370 m_ClassesProbValidation[itmap->first] = minSizeValidation / itmap->second;
// Strategy B: keep each class's own size and the train/validation split,
// scaling both probabilities by the tighter of the two cap ratios so that
// neither expected count exceeds its cap.
// Expected sizes are truncated to integers. The divisions below are safe:
// a cap is applied only when 0 <= cap < maxSize, which implies maxSize >= 1.
374 long int maxSizeT = (itmap->second) * (1.0 - m_ValidationTrainingProportion);
375 long int maxSizeV = (itmap->second) * m_ValidationTrainingProportion;
378 double correctionRatioTrain = 1.0;
379 if ((m_MaxTrainingSize > -1) && (m_MaxTrainingSize < maxSizeT))
381 correctionRatioTrain = (double)(m_MaxTrainingSize) / (double)(maxSizeT);
383 double correctionRatioValid = 1.0;
384 if ((m_MaxValidationSize > -1) && (m_MaxValidationSize < maxSizeV))
386 correctionRatioValid = (double)(m_MaxValidationSize) / (double)(maxSizeV);
388 double correctionRatio = std::min(correctionRatioTrain, correctionRatioValid);
389 m_ClassesProbTraining[itmap->first] = correctionRatio * (1.0 - m_ValidationTrainingProportion);
390 m_ClassesProbValidation[itmap->first] = correctionRatio * m_ValidationTrainingProportion;
// Returns the polygon's area expressed in pixels: the exterior ring area
// minus the area of every interior ring (hole), each divided by the
// absolute area of one pixel.
// NOTE(review): assumes GetArea() returns a non-negative area regardless of
// ring orientation — confirm. A zero spacing component makes pixelArea 0
// and the result inf/NaN.
394 template <
class TImage,
class TVectorData>
397 const double pixelArea = std::abs(image->GetSignedSpacing()[0] * image->GetSignedSpacing()[1]);
401 double area = exteriorRing->GetArea() / pixelArea;
405 for (
typename PolygonListType::Iterator interiorRing = interiorRings->Begin(); interiorRing != interiorRings->End(); ++interiorRing)
407 area -= interiorRing.Get()->GetArea() / pixelArea;
// Prints the sampling parameters and the following per-class diagnostics:
// - polygon area in pixels;
// - selection probabilities (training and validation);
// - number of samples actually drawn (training and validation).
413 template <
class TImage,
class TVectorData>
416 os << indent <<
"* MaxTrainingSize: " << m_MaxTrainingSize <<
"\n";
417 os << indent <<
"* MaxValidationSize: " << m_MaxValidationSize <<
"\n";
418 os << indent <<
"* Proportion: " << m_ValidationTrainingProportion <<
"\n";
// Per-class area in pixel units, as accumulated from the polygons.
419 os << indent <<
"* Input data:\n";
420 if (m_ClassesSize.empty())
422 os << indent <<
"Empty\n";
426 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
428 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
// Training set: selection probability and realised sample count per class.
432 os <<
"\n" << indent <<
"* Training set:\n";
433 if (m_ClassesProbTraining.empty())
435 os << indent <<
"Not computed\n";
439 os << indent <<
"** Selection probability:\n";
440 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesProbTraining.begin(); itmap != m_ClassesProbTraining.end(); ++itmap)
442 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
444 os << indent <<
"** Number of selected samples:\n";
445 for (std::map<ClassLabelType, int>::const_iterator itmap = m_ClassesSamplesNumberTraining.begin(); itmap != m_ClassesSamplesNumberTraining.end(); ++itmap)
447 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
// Validation set: same layout as the training set section.
451 os <<
"\n" << indent <<
"* Validation set:\n";
452 if (m_ClassesProbValidation.empty())
454 os << indent <<
"Not computed\n";
458 os << indent <<
"** Selection probability:\n";
459 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesProbValidation.begin(); itmap != m_ClassesProbValidation.end(); ++itmap)
461 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
463 os << indent <<
"** Number of selected samples:\n";
464 for (std::map<ClassLabelType, int>::const_iterator itmap = m_ClassesSamplesNumberValidation.begin(); itmap != m_ClassesSamplesNumberValidation.end();
467 os << indent << itmap->first <<
": " << itmap->second <<
"\n";