21 #ifndef otbListSampleGenerator_hxx
22 #define otbListSampleGenerator_hxx
26 #include "itkImageRegionConstIteratorWithIndex.h"
// Constructor: sets default sampling parameters and declares the filter's
// inputs and outputs.
//  - m_MaxTrainingSize / m_MaxValidationSize = -1: the "no cap" sentinel
//    tested in ComputeClassSelectionProbability.
//  - m_ValidationTrainingProportion = 0.0: the validation selection
//    probability is zero, so by default no pixel goes to the validation set.
//  - m_PolygonEdgeInclusion = false: the extra IsOnEdge() test used when
//    deciding whether a pixel belongs to a polygon ring is disabled.
// NOTE(review): this view is partial — the constructor signature, some
// initialisers and the braces are not visible here.
template <
class TImage,
class TVectorData>
36 : m_MaxTrainingSize(-1),
37 m_MaxValidationSize(-1),
38 m_ValidationTrainingProportion(0.0),
40 m_PolygonEdgeInclusion(false),
// Input 0 is the image (SetInput), input 1 the vector data
// (SetInputVectorData).
45 this->SetNumberOfRequiredInputs(2);
46 this->SetNumberOfRequiredOutputs(4);
// Allocate the four outputs through MakeOutput right away, so they exist
// (and the list accessors return non-null objects) as soon as the filter is
// constructed.
49 this->itk::ProcessObject::SetNthOutput(0, this->
MakeOutput(0).GetPointer());
50 this->itk::ProcessObject::SetNthOutput(1, this->
MakeOutput(1).GetPointer());
51 this->itk::ProcessObject::SetNthOutput(2, this->
MakeOutput(2).GetPointer());
52 this->itk::ProcessObject::SetNthOutput(3, this->
MakeOutput(3).GetPointer());
// Sets input 0: the image whose pixels are sampled.
// The const_cast is presumably needed because ProcessObject::SetNthInput
// takes a non-const DataObject pointer; the filter only reads the image.
template <
class TImage,
class TVectorData>
60 this->ProcessObject::SetNthInput(0,
const_cast<ImageType*
>(image));
// Returns input 0 (the image) as a const ImageType pointer.
// The guard checks that at least one input slot exists; the statement it
// guards is not visible here (presumably an early return of nullptr).
template <
class TImage,
class TVectorData>
66 if (this->GetNumberOfInputs() < 1)
71 return static_cast<const ImageType*
>(this->ProcessObject::GetInput(0));
// Sets input 1: the vector data whose polygon features (with their class
// label field) define where and how pixels are sampled.
// The const_cast is presumably needed because SetNthInput takes a non-const
// pointer.
template <
class TImage,
class TVectorData>
77 this->ProcessObject::SetNthInput(1,
const_cast<VectorDataType*
>(vectorData));
// Returns input 1 (the vector data) as a const VectorDataType pointer.
// The guard checks that the second input slot exists; the statement it
// guards is not visible here (presumably an early return of nullptr).
template <
class TImage,
class TVectorData>
85 if (this->GetNumberOfInputs() < 2)
90 return static_cast<const VectorDataType*
>(this->ProcessObject::GetInput(1));
// Factory for the filter's output objects, called with an output index
// (see the constructor's SetNthOutput calls).
// Consistent with the accessors, indices 0 and 2 presumably receive a
// ListSampleType (training / validation pixel vectors) and 1 and 3 a
// ListLabelType (training / validation labels). The final ListSampleType
// allocation presumably serves as the fallback for any other index.
// NOTE(review): the switch/if lines selecting each case are not visible
// here — TODO confirm the index-to-type mapping.
template <
class TImage,
class TVectorData>
100 output =
static_cast<itk::DataObject*
>(ListSampleType::New().GetPointer());
103 output =
static_cast<itk::DataObject*
>(ListLabelType::New().GetPointer());
106 output =
static_cast<itk::DataObject*
>(ListSampleType::New().GetPointer());
109 output =
static_cast<itk::DataObject*
>(ListLabelType::New().GetPointer());
112 output =
static_cast<itk::DataObject*
>(ListSampleType::New().GetPointer());
// Returns output 0, the list of selected training pixel values, downcast to
// ListSampleType via dynamic_cast. The result is nullptr if that output
// holds an object of another type.
template <
class TImage,
class TVectorData>
121 return dynamic_cast<ListSampleType*
>(this->itk::ProcessObject::GetOutput(0));
// Returns output 1, the class labels matching the training samples, downcast
// to ListLabelType via dynamic_cast. The result is nullptr if that output
// holds an object of another type.
template <
class TImage,
class TVectorData>
127 return dynamic_cast<ListLabelType*
>(this->itk::ProcessObject::GetOutput(1));
// Returns output 2, the list of selected validation pixel values, downcast
// to ListSampleType via dynamic_cast. The result is nullptr if that output
// holds an object of another type.
template <
class TImage,
class TVectorData>
134 return dynamic_cast<ListSampleType*
>(this->itk::ProcessObject::GetOutput(2));
// Returns output 3, the class labels matching the validation samples,
// downcast to ListLabelType via dynamic_cast. The result is nullptr if that
// output holds an object of another type.
template <
class TImage,
class TVectorData>
142 return dynamic_cast<ListLabelType*
>(this->itk::ProcessObject::GetOutput(3));
// Sets the image's requested region to a "dummy" region: a default-constructed
// RegionType whose only explicit setting is its size.
// The line initialising dummySize is not visible here (presumably it is
// zero-filled). Pixel data is instead requested polygon by polygon inside
// GenerateData (SetRequestedRegion / UpdateOutputData). Presumably this
// prevents the pipeline from producing the whole image up front.
template <
class TImage,
class TVectorData>
156 typename ImageType::RegionType dummyRegion;
157 typename ImageType::SizeType dummySize;
159 dummyRegion.SetSize(dummySize);
160 img->SetRequestedRegion(dummyRegion);
// Fills the four output lists by randomly sampling image pixels inside the
// polygon features of the input vector data.
//
// For every polygon feature, each pixel of the polygon's image-clipped region
// is a candidate when both of the following hold:
//   - it lies inside the exterior ring (or on its edge, if
//     PolygonEdgeInclusion is on);
//   - it does not lie inside, or on the edge of, an interior ring.
//
// A uniform random value r between 0 and 1 then decides the candidate's fate.
// Here pT / pV are the per-class probabilities from
// ComputeClassSelectionProbability:
//   r < pT          -> training lists (pixel value + class label)
//   r < pT + pV     -> validation lists
//   otherwise       -> discarded
//
// The class label is the integer field m_ClassKey of the feature. Per-class
// counts of selected samples are accumulated for PrintSelf.
//
// NOTE(review): this view is partial — the signature, braces and several
// statements (fetching inputs/outputs, computing polygonRegion and
// interiorRings, the skip statements) are not visible here.
template <
class TImage,
class TVectorData>
// Per-class areas first, then per-class selection probabilities (the latter
// throws if no class was found).
180 this->GenerateClassStatistics();
182 this->ComputeClassSelectionProbability();
// Empty all four outputs and the per-class counters so that re-running the
// filter does not accumulate samples from a previous update.
185 trainingListSample->Clear();
186 trainingListLabel->Clear();
187 validationListSample->Clear();
188 validationListLabel->Clear();
// Sample vectors have as many components as an image pixel; labels are scalar.
191 trainingListSample->SetMeasurementVectorSize(image->GetNumberOfComponentsPerPixel());
193 trainingListLabel->SetMeasurementVectorSize(1);
194 validationListSample->SetMeasurementVectorSize(image->GetNumberOfComponentsPerPixel());
196 validationListLabel->SetMeasurementVectorSize(1);
198 m_ClassesSamplesNumberTraining.clear();
199 m_ClassesSamplesNumberValidation.clear();
201 typename ImageType::RegionType imageLargestRegion = image->GetLargestPossibleRegion();
203 auto itVectorPair = vectorData->GetIteratorPair();
204 auto currentIt = itVectorPair.first;
205 for (; currentIt != itVectorPair.second; ++currentIt)
207 if (vectorData->Get(currentIt)->IsPolygonFeature())
209 PolygonPointerType exteriorRing = vectorData->Get(currentIt)->GetPolygonExteriorRing();
// Clip the polygon's index region to the image extent. Polygons that do not
// overlap the image at all are skipped; the skip statement is not visible
// here.
213 const bool hasIntersection = polygonRegion.Crop(imageLargestRegion);
214 if (!hasIntersection)
// Bring only this polygon's region of the image up to date before iterating.
219 image->SetRequestedRegion(polygonRegion);
220 image->PropagateRequestedRegion();
221 image->UpdateOutputData();
223 typedef itk::ImageRegionConstIteratorWithIndex<ImageType> IteratorType;
224 IteratorType it(image, polygonRegion);
226 for (it.GoToBegin(); !it.IsAtEnd(); ++it)
// The pixel index is converted to a physical position, which is what the
// polygon rings are tested against.
228 itk::ContinuousIndex<double, 2> point;
229 image->TransformIndexToPhysicalPoint(it.GetIndex(), point);
231 if (exteriorRing->IsInside(point) || (this->GetPolygonEdgeInclusion() && exteriorRing->IsOnEdge(point)))
// A pixel inside any interior ring (a hole of the polygon) is not a candidate.
235 bool isInsideInteriorRing =
false;
236 for (
typename PolygonListType::Iterator interiorRing = interiorRings->Begin(); interiorRing != interiorRings->End(); ++interiorRing)
238 if (interiorRing.Get()->IsInside(point) || (this->GetPolygonEdgeInclusion() && interiorRing.Get()->IsOnEdge(point)))
240 isInsideInteriorRing =
true;
// Presumably the pixel is skipped here; the branch body is not visible.
244 if (isInsideInteriorRing)
// One random draw per candidate pixel decides training / validation / discard.
// NOTE(review): the class label GetFieldAsInt(m_ClassKey) is re-evaluated up
// to six times per pixel. It depends only on the current feature, so it
// could be computed once per feature, outside the pixel loop.
249 double randomValue = m_RandomGenerator->GetUniformVariate(0.0, 1.0);
250 if (randomValue < m_ClassesProbTraining[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)])
253 trainingListSample->PushBack(it.Get());
254 trainingListLabel->PushBack(vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey));
255 m_ClassesSamplesNumberTraining[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)] += 1;
// The validation band starts where the training band ends: [pT, pT + pV).
257 else if (randomValue <
258 m_ClassesProbTraining[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)] + m_ClassesProbValidation[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)])
261 validationListSample->PushBack(it.Get());
262 validationListLabel->PushBack(vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey));
263 m_ClassesSamplesNumberValidation[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)] += 1;
// Each sample list must stay paired one-to-one with its label list.
271 assert(trainingListSample->Size() == trainingListLabel->Size());
272 assert(validationListSample->Size() == validationListLabel->Size());
273 this->UpdateProgress(1.0f);
// Computes the size of each class.
// For every polygon feature of the vector data, its area in pixel units
// (GetPolygonAreaInPixelsUnits) is added to m_ClassesSize under the feature's
// integer field m_ClassKey. Non-polygon features are ignored.
// m_NumberOfClasses is set to the number of distinct labels encountered.
// NOTE(review): the line obtaining `image` is not visible here.
// NOTE(review): as far as visible, polygon areas are not clipped to the image
// extent, whereas GenerateData only samples the part of each polygon that
// overlaps the image. Classes with polygons partly or fully outside the image
// may therefore have overestimated sizes, which biases the selection
// probabilities.
template <
class TImage,
class TVectorData>
279 m_ClassesSize.clear();
282 typename VectorDataType::ConstPointer vectorData = this->GetInputVectorData();
285 auto itVectorPair = vectorData->GetIteratorPair();
286 auto currentIt = itVectorPair.first;
287 for (; currentIt != itVectorPair.second; ++currentIt)
289 typename DataNodeType::Pointer datanode = vectorData->Get(currentIt);
290 if (datanode->IsPolygonFeature())
292 double area = GetPolygonAreaInPixelsUnits(datanode, image);
293 m_ClassesSize[datanode->GetFieldAsInt(m_ClassKey)] += area;
296 m_NumberOfClasses = m_ClassesSize.size();
// Computes, for each class label, the probability that a candidate pixel of
// that class is put into the training set (m_ClassesProbTraining) or into the
// validation set (m_ClassesProbValidation).
// Requires GenerateClassStatistics to have filled m_ClassesSize, and throws an
// itk exception when it is empty.
// Two strategies appear below. The line(s) selecting between them are not
// visible here (presumably a boolean member — TODO confirm).
template <
class TImage,
class TVectorData>
302 m_ClassesProbTraining.clear();
303 m_ClassesProbValidation.clear();
306 if (m_ClassesSize.empty())
308 itkGenericExceptionMacro(<<
"No training sample found inside image");
// Size (in pixels) of the smallest class.
312 double minSize = itk::NumericTraits<double>::max();
313 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
315 if (minSize > itmap->second)
317 minSize = itmap->second;
// Split the smallest class size between training and validation according to
// the proportion, then cap each part by its max size.
// NOTE(review): the "no cap" sentinel is tested here with != -1, while the
// second strategy below uses > -1. Any other negative max size would be used
// as a cap here and yield negative probabilities (i.e. no sample selected).
// Confirm the intended convention.
322 double minSizeTraining = minSize * (1.0 - m_ValidationTrainingProportion);
323 double minSizeValidation = minSize * m_ValidationTrainingProportion;
328 if ((m_MaxTrainingSize != -1) && (m_MaxTrainingSize < minSizeTraining))
330 minSizeTraining = m_MaxTrainingSize;
332 if ((m_MaxValidationSize != -1) && (m_MaxValidationSize < minSizeValidation))
334 minSizeValidation = m_MaxValidationSize;
338 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
// Strategy 1 (balanced): every class is subsampled so that its expected
// number of selected pixels equals the capped smallest-class share.
342 m_ClassesProbTraining[itmap->first] = minSizeTraining / itmap->second;
343 m_ClassesProbValidation[itmap->first] = minSizeValidation / itmap->second;
// Strategy 2 (per-class): each class keeps its own size split by the
// proportion. Both probabilities are scaled by the same factor (the stricter
// of the two corrections) so that neither max size is exceeded in
// expectation. The double-to-long conversions truncate toward zero.
347 long int maxSizeT = (itmap->second) * (1.0 - m_ValidationTrainingProportion);
348 long int maxSizeV = (itmap->second) * m_ValidationTrainingProportion;
// The divisions are safe: each is reached only when 0 <= max size < maxSize*,
// so the divisor is positive.
351 double correctionRatioTrain = 1.0;
352 if ((m_MaxTrainingSize > -1) && (m_MaxTrainingSize < maxSizeT))
354 correctionRatioTrain = (double)(m_MaxTrainingSize) / (double)(maxSizeT);
356 double correctionRatioValid = 1.0;
357 if ((m_MaxValidationSize > -1) && (m_MaxValidationSize < maxSizeV))
359 correctionRatioValid = (double)(m_MaxValidationSize) / (double)(maxSizeV);
361 double correctionRatio = std::min(correctionRatioTrain, correctionRatioValid);
362 m_ClassesProbTraining[itmap->first] = correctionRatio * (1.0 - m_ValidationTrainingProportion);
363 m_ClassesProbValidation[itmap->first] = correctionRatio * m_ValidationTrainingProportion;
// Returns the area of a polygon data node expressed in pixels.
// The result is the exterior ring's area minus the area of every interior
// ring (hole), each divided by the area of one pixel. The pixel area is the
// absolute value of spacing[0] * spacing[1]; std::abs is applied because the
// signed spacing may be negative.
// Presumably GetArea() returns the area in the same physical units as the
// image spacing — TODO confirm. The lines extracting exteriorRing and
// interiorRings from the node are not visible here.
template <
class TImage,
class TVectorData>
370 const double pixelArea = std::abs(image->GetSignedSpacing()[0] * image->GetSignedSpacing()[1]);
374 double area = exteriorRing->GetArea() / pixelArea;
378 for (
typename PolygonListType::Iterator interiorRing = interiorRings->Begin(); interiorRing != interiorRings->End(); ++interiorRing)
380 area -= interiorRing.Get()->GetArea() / pixelArea;
// Prints the filter state in this order:
//   - the sampling parameters;
//   - the per-class input sizes (m_ClassesSize);
//   - for the training set and the validation set, the per-class selection
//     probabilities and the number of samples actually selected.
// "Empty" / "Not computed" are printed while the corresponding maps are still
// empty, i.e. before the statistics / probabilities have been computed.
template <
class TImage,
class TVectorData>
389 os << indent <<
"* MaxTrainingSize: " << m_MaxTrainingSize <<
"\n";
390 os << indent <<
"* MaxValidationSize: " << m_MaxValidationSize <<
"\n";
391 os << indent <<
"* Proportion: " << m_ValidationTrainingProportion <<
"\n";
// Per-class sizes, in pixel units.
392 os << indent <<
"* Input data:\n";
393 if (m_ClassesSize.empty())
395 os << indent <<
"Empty\n";
399 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
401 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
// Training set: selection probability and selected-sample count per class.
405 os <<
"\n" << indent <<
"* Training set:\n";
406 if (m_ClassesProbTraining.empty())
408 os << indent <<
"Not computed\n";
412 os << indent <<
"** Selection probability:\n";
413 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesProbTraining.begin(); itmap != m_ClassesProbTraining.end(); ++itmap)
415 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
417 os << indent <<
"** Number of selected samples:\n";
418 for (std::map<ClassLabelType, int>::const_iterator itmap = m_ClassesSamplesNumberTraining.begin(); itmap != m_ClassesSamplesNumberTraining.end(); ++itmap)
420 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
// Validation set: same layout as the training set.
424 os <<
"\n" << indent <<
"* Validation set:\n";
425 if (m_ClassesProbValidation.empty())
427 os << indent <<
"Not computed\n";
431 os << indent <<
"** Selection probability:\n";
432 for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesProbValidation.begin(); itmap != m_ClassesProbValidation.end(); ++itmap)
434 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
436 os << indent <<
"** Number of selected samples:\n";
437 for (std::map<ClassLabelType, int>::const_iterator itmap = m_ClassesSamplesNumberValidation.begin(); itmap != m_ClassesSamplesNumberValidation.end();
440 os << indent << itmap->first <<
": " << itmap->second <<
"\n";
ListLabelType * GetValidationListLabel()
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) override
void SetInput(const ImageType *)
ListSampleType::Pointer ListSamplePointerType
VectorDataType::Pointer VectorDataPointerType
itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType
void GenerateData(void) override
itk::DataObject::Pointer DataObjectPointer
ListSampleType * GetTrainingListSample()
double GetPolygonAreaInPixelsUnits(DataNodeType *polygonDataNode, ImageType *image)
const VectorDataType * GetInputVectorData() const
void PrintSelf(std::ostream &os, itk::Indent indent) const override
ImageType::Pointer ImagePointerType
ListLabelType::Pointer ListLabelPointerType
TVectorData VectorDataType
void GenerateInputRequestedRegion(void) override
void SetInputVectorData(const VectorDataType *)
itk::Statistics::ListSample< SampleType > ListSampleType
DataNodeType::PolygonListPointerType PolygonListPointerType
const ImageType * GetInput() const
ListSampleType * GetValidationListSample()
ListLabelType * GetTrainingListLabel()
void GenerateClassStatistics()
VectorDataType::DataNodeType DataNodeType
DataNodeType::PolygonPointerType PolygonPointerType
void ComputeClassSelectionProbability()
RandomGeneratorType::Pointer m_RandomGenerator
itk::Statistics::ListSample< LabelType > ListLabelType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
ImageType::RegionType TransformPhysicalRegionToIndexRegion(const RemoteSensingRegionType &region, const ImageType *image)