OTB  10.0.0
Orfeo Toolbox
otbListSampleGenerator.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbListSampleGenerator_hxx
22 #define otbListSampleGenerator_hxx
23 
24 #include "otbListSampleGenerator.h"
25 
26 #include "itkImageRegionConstIteratorWithIndex.h"
28 
29 #include "otbMacro.h"
30 
31 namespace otb
32 {
33 
34 template <class TImage, class TVectorData>
36  : m_MaxTrainingSize(-1),
37  m_MaxValidationSize(-1),
38  m_ValidationTrainingProportion(0.0),
39  m_BoundByMin(true),
40  m_PolygonEdgeInclusion(false),
41  m_NumberOfClasses(0),
42  m_ClassKey("Class"),
43  m_ClassMinSize(-1)
44 {
45  this->SetNumberOfRequiredInputs(2);
46  this->SetNumberOfRequiredOutputs(4);
47 
48  // Register the outputs
49  this->itk::ProcessObject::SetNthOutput(0, this->MakeOutput(0).GetPointer());
50  this->itk::ProcessObject::SetNthOutput(1, this->MakeOutput(1).GetPointer());
51  this->itk::ProcessObject::SetNthOutput(2, this->MakeOutput(2).GetPointer());
52  this->itk::ProcessObject::SetNthOutput(3, this->MakeOutput(3).GetPointer());
53 
54  m_RandomGenerator = RandomGeneratorType::GetInstance();
55 }
56 
57 template <class TImage, class TVectorData>
59 {
60  this->ProcessObject::SetNthInput(0, const_cast<ImageType*>(image));
61 }
62 
63 template <class TImage, class TVectorData>
65 {
66  if (this->GetNumberOfInputs() < 1)
67  {
68  return nullptr;
69  }
70 
71  return static_cast<const ImageType*>(this->ProcessObject::GetInput(0));
72 }
73 
74 template <class TImage, class TVectorData>
76 {
77  this->ProcessObject::SetNthInput(1, const_cast<VectorDataType*>(vectorData));
78 
79  // printVectorData(vectorData);
80 }
81 
82 template <class TImage, class TVectorData>
84 {
85  if (this->GetNumberOfInputs() < 2)
86  {
87  return nullptr;
88  }
89 
90  return static_cast<const VectorDataType*>(this->ProcessObject::GetInput(1));
91 }
92 
93 template <class TImage, class TVectorData>
95 {
96  DataObjectPointer output;
97  switch (idx)
98  {
99  case 0:
100  output = static_cast<itk::DataObject*>(ListSampleType::New().GetPointer());
101  break;
102  case 1:
103  output = static_cast<itk::DataObject*>(ListLabelType::New().GetPointer());
104  break;
105  case 2:
106  output = static_cast<itk::DataObject*>(ListSampleType::New().GetPointer());
107  break;
108  case 3:
109  output = static_cast<itk::DataObject*>(ListLabelType::New().GetPointer());
110  break;
111  default:
112  output = static_cast<itk::DataObject*>(ListSampleType::New().GetPointer());
113  break;
114  }
115  return output;
116 }
117 // Get the Training ListSample
118 template <class TImage, class TVectorData>
120 {
121  return dynamic_cast<ListSampleType*>(this->itk::ProcessObject::GetOutput(0));
122 }
123 // Get the Training label ListSample
124 template <class TImage, class TVectorData>
126 {
127  return dynamic_cast<ListLabelType*>(this->itk::ProcessObject::GetOutput(1));
128 }
129 
130 // Get the validation ListSample
131 template <class TImage, class TVectorData>
133 {
134  return dynamic_cast<ListSampleType*>(this->itk::ProcessObject::GetOutput(2));
135 }
136 
137 
138 // Get the validation label ListSample
139 template <class TImage, class TVectorData>
141 {
142  return dynamic_cast<ListLabelType*>(this->itk::ProcessObject::GetOutput(3));
143 }
144 
145 template <class TImage, class TVectorData>
147 {
148  ImagePointerType img = static_cast<ImageType*>(this->ProcessObject::GetInput(0));
149 
150  if (img.IsNotNull())
151  {
152 
153  // Requested regions will be generated during GenerateData
154  // call. For now request an empty region so as to avoid requesting
155  // the largest possible region (fixes bug #943 )
156  typename ImageType::RegionType dummyRegion;
157  typename ImageType::SizeType dummySize;
158  dummySize.Fill(0);
159  dummyRegion.SetSize(dummySize);
160  img->SetRequestedRegion(dummyRegion);
161  }
162 }
163 
164 
165 template <class TImage, class TVectorData>
167 {
168  // Get the inputs
169  ImagePointerType image = const_cast<ImageType*>(this->GetInput());
170  VectorDataPointerType vectorData = const_cast<VectorDataType*>(this->GetInputVectorData());
171 
172  // Get the outputs
173  ListSamplePointerType trainingListSample = this->GetTrainingListSample();
174  ListLabelPointerType trainingListLabel = this->GetTrainingListLabel();
175  ListSamplePointerType validationListSample = this->GetValidationListSample();
176  ListLabelPointerType validationListLabel = this->GetValidationListLabel();
177 
178  // Gather some information about the relative size of the classes
179  // We would like to have the same number of samples per class
180  this->GenerateClassStatistics();
181 
182  this->ComputeClassSelectionProbability();
183 
184  // Clear the sample lists
185  trainingListSample->Clear();
186  trainingListLabel->Clear();
187  validationListSample->Clear();
188  validationListLabel->Clear();
189 
190  // Set MeasurementVectorSize for each sample list
191  trainingListSample->SetMeasurementVectorSize(image->GetNumberOfComponentsPerPixel());
192  // stores label as integers,so put the size to 1
193  trainingListLabel->SetMeasurementVectorSize(1);
194  validationListSample->SetMeasurementVectorSize(image->GetNumberOfComponentsPerPixel());
195  // stores label as integers,so put the size to 1
196  validationListLabel->SetMeasurementVectorSize(1);
197 
198  m_ClassesSamplesNumberTraining.clear();
199  m_ClassesSamplesNumberValidation.clear();
200 
201  typename ImageType::RegionType imageLargestRegion = image->GetLargestPossibleRegion();
202 
203  auto itVectorPair = vectorData->GetIteratorPair();
204  auto currentIt = itVectorPair.first;
205  for (; currentIt != itVectorPair.second; ++currentIt)
206  {
207  if (vectorData->Get(currentIt)->IsPolygonFeature())
208  {
209  PolygonPointerType exteriorRing = vectorData->Get(currentIt)->GetPolygonExteriorRing();
210 
211  typename ImageType::RegionType polygonRegion = otb::TransformPhysicalRegionToIndexRegion(exteriorRing->GetBoundingRegion(), image.GetPointer());
212 
213  const bool hasIntersection = polygonRegion.Crop(imageLargestRegion);
214  if (!hasIntersection)
215  {
216  continue;
217  }
218 
219  image->SetRequestedRegion(polygonRegion);
220  image->PropagateRequestedRegion();
221  image->UpdateOutputData();
222 
223  typedef itk::ImageRegionConstIteratorWithIndex<ImageType> IteratorType;
224  IteratorType it(image, polygonRegion);
225 
226  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
227  {
228  itk::ContinuousIndex<double, 2> point;
229  image->TransformIndexToPhysicalPoint(it.GetIndex(), point);
230 
231  if (exteriorRing->IsInside(point) || (this->GetPolygonEdgeInclusion() && exteriorRing->IsOnEdge(point)))
232  {
233  PolygonListPointerType interiorRings = vectorData->Get(currentIt)->GetPolygonInteriorRings();
234 
235  bool isInsideInteriorRing = false;
236  for (typename PolygonListType::Iterator interiorRing = interiorRings->Begin(); interiorRing != interiorRings->End(); ++interiorRing)
237  {
238  if (interiorRing.Get()->IsInside(point) || (this->GetPolygonEdgeInclusion() && interiorRing.Get()->IsOnEdge(point)))
239  {
240  isInsideInteriorRing = true;
241  break;
242  }
243  }
244  if (isInsideInteriorRing)
245  {
246  continue; // skip this pixel and continue
247  }
248 
249  double randomValue = m_RandomGenerator->GetUniformVariate(0.0, 1.0);
250  if (randomValue < m_ClassesProbTraining[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)])
251  {
252  // Add the sample to the training list
253  trainingListSample->PushBack(it.Get());
254  trainingListLabel->PushBack(vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey));
255  m_ClassesSamplesNumberTraining[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)] += 1;
256  }
257  else if (randomValue <
258  m_ClassesProbTraining[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)] + m_ClassesProbValidation[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)])
259  {
260  // Add the sample to the validation list
261  validationListSample->PushBack(it.Get());
262  validationListLabel->PushBack(vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey));
263  m_ClassesSamplesNumberValidation[vectorData->Get(currentIt)->GetFieldAsInt(m_ClassKey)] += 1;
264  }
265  // Note: some samples may not be used at all
266  }
267  }
268  }
269  }
270 
271  assert(trainingListSample->Size() == trainingListLabel->Size());
272  assert(validationListSample->Size() == validationListLabel->Size());
273  this->UpdateProgress(1.0f);
274 }
275 
276 template <class TImage, class TVectorData>
278 {
279  m_ClassesSize.clear();
280 
281  ImageType* image = const_cast<ImageType*>(this->GetInput());
282  typename VectorDataType::ConstPointer vectorData = this->GetInputVectorData();
283 
284  // Compute cumulative area of all polygons of each class
285  auto itVectorPair = vectorData->GetIteratorPair();
286  auto currentIt = itVectorPair.first;
287  for (; currentIt != itVectorPair.second; ++currentIt)
288  {
289  typename DataNodeType::Pointer datanode = vectorData->Get(currentIt);
290  if (datanode->IsPolygonFeature())
291  {
292  double area = GetPolygonAreaInPixelsUnits(datanode, image);
293  m_ClassesSize[datanode->GetFieldAsInt(m_ClassKey)] += area;
294  }
295  }
296  m_NumberOfClasses = m_ClassesSize.size();
297 }
298 
299 template <class TImage, class TVectorData>
301 {
302  m_ClassesProbTraining.clear();
303  m_ClassesProbValidation.clear();
304 
305  // Sanity check
306  if (m_ClassesSize.empty())
307  {
308  itkGenericExceptionMacro(<< "No training sample found inside image");
309  }
310 
311  // Go through the classes size to find the smallest one
312  double minSize = itk::NumericTraits<double>::max();
313  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
314  {
315  if (minSize > itmap->second)
316  {
317  minSize = itmap->second;
318  }
319  }
320 
321  // Apply the proportion between training and validation samples (all training by default)
322  double minSizeTraining = minSize * (1.0 - m_ValidationTrainingProportion);
323  double minSizeValidation = minSize * m_ValidationTrainingProportion;
324 
325  // Apply the limit if specified by the user
326  if (m_BoundByMin)
327  {
328  if ((m_MaxTrainingSize != -1) && (m_MaxTrainingSize < minSizeTraining))
329  {
330  minSizeTraining = m_MaxTrainingSize;
331  }
332  if ((m_MaxValidationSize != -1) && (m_MaxValidationSize < minSizeValidation))
333  {
334  minSizeValidation = m_MaxValidationSize;
335  }
336  }
337  // Compute the probability selection for each class
338  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
339  {
340  if (m_BoundByMin)
341  {
342  m_ClassesProbTraining[itmap->first] = minSizeTraining / itmap->second;
343  m_ClassesProbValidation[itmap->first] = minSizeValidation / itmap->second;
344  }
345  else
346  {
347  long int maxSizeT = (itmap->second) * (1.0 - m_ValidationTrainingProportion);
348  long int maxSizeV = (itmap->second) * m_ValidationTrainingProportion;
349 
350  // Check if max sizes respect the maximum bounds
351  double correctionRatioTrain = 1.0;
352  if ((m_MaxTrainingSize > -1) && (m_MaxTrainingSize < maxSizeT))
353  {
354  correctionRatioTrain = (double)(m_MaxTrainingSize) / (double)(maxSizeT);
355  }
356  double correctionRatioValid = 1.0;
357  if ((m_MaxValidationSize > -1) && (m_MaxValidationSize < maxSizeV))
358  {
359  correctionRatioValid = (double)(m_MaxValidationSize) / (double)(maxSizeV);
360  }
361  double correctionRatio = std::min(correctionRatioTrain, correctionRatioValid);
362  m_ClassesProbTraining[itmap->first] = correctionRatio * (1.0 - m_ValidationTrainingProportion);
363  m_ClassesProbValidation[itmap->first] = correctionRatio * m_ValidationTrainingProportion;
364  }
365  }
366 }
367 template <class TImage, class TVectorData>
369 {
370  const double pixelArea = std::abs(image->GetSignedSpacing()[0] * image->GetSignedSpacing()[1]);
371 
372  // Compute area of exterior ring in pixels
373  PolygonPointerType exteriorRing = polygonDataNode->GetPolygonExteriorRing();
374  double area = exteriorRing->GetArea() / pixelArea;
375 
376  // Remove contribution of all interior rings
377  PolygonListPointerType interiorRings = polygonDataNode->GetPolygonInteriorRings();
378  for (typename PolygonListType::Iterator interiorRing = interiorRings->Begin(); interiorRing != interiorRings->End(); ++interiorRing)
379  {
380  area -= interiorRing.Get()->GetArea() / pixelArea;
381  }
382 
383  return area;
384 }
385 
386 template <class TImage, class TVectorData>
387 void ListSampleGenerator<TImage, TVectorData>::PrintSelf(std::ostream& os, itk::Indent indent) const
388 {
389  os << indent << "* MaxTrainingSize: " << m_MaxTrainingSize << "\n";
390  os << indent << "* MaxValidationSize: " << m_MaxValidationSize << "\n";
391  os << indent << "* Proportion: " << m_ValidationTrainingProportion << "\n";
392  os << indent << "* Input data:\n";
393  if (m_ClassesSize.empty())
394  {
395  os << indent << "Empty\n";
396  }
397  else
398  {
399  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
400  {
401  os << indent << itmap->first << ": " << itmap->second << "\n";
402  }
403  }
404 
405  os << "\n" << indent << "* Training set:\n";
406  if (m_ClassesProbTraining.empty())
407  {
408  os << indent << "Not computed\n";
409  }
410  else
411  {
412  os << indent << "** Selection probability:\n";
413  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesProbTraining.begin(); itmap != m_ClassesProbTraining.end(); ++itmap)
414  {
415  os << indent << itmap->first << ": " << itmap->second << "\n";
416  }
417  os << indent << "** Number of selected samples:\n";
418  for (std::map<ClassLabelType, int>::const_iterator itmap = m_ClassesSamplesNumberTraining.begin(); itmap != m_ClassesSamplesNumberTraining.end(); ++itmap)
419  {
420  os << indent << itmap->first << ": " << itmap->second << "\n";
421  }
422  }
423 
424  os << "\n" << indent << "* Validation set:\n";
425  if (m_ClassesProbValidation.empty())
426  {
427  os << indent << "Not computed\n";
428  }
429  else
430  {
431  os << indent << "** Selection probability:\n";
432  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesProbValidation.begin(); itmap != m_ClassesProbValidation.end(); ++itmap)
433  {
434  os << indent << itmap->first << ": " << itmap->second << "\n";
435  }
436  os << indent << "** Number of selected samples:\n";
437  for (std::map<ClassLabelType, int>::const_iterator itmap = m_ClassesSamplesNumberValidation.begin(); itmap != m_ClassesSamplesNumberValidation.end();
438  ++itmap)
439  {
440  os << indent << itmap->first << ": " << itmap->second << "\n";
441  }
442  }
443 }
444 }
445 
446 #endif
ListLabelType * GetValidationListLabel()
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) override
void SetInput(const ImageType *)
ListSampleType::Pointer ListSamplePointerType
VectorDataType::Pointer VectorDataPointerType
itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType
itk::DataObject::Pointer DataObjectPointer
ListSampleType * GetTrainingListSample()
double GetPolygonAreaInPixelsUnits(DataNodeType *polygonDataNode, ImageType *image)
const VectorDataType * GetInputVectorData() const
void PrintSelf(std::ostream &os, itk::Indent indent) const override
ImageType::Pointer ImagePointerType
ListLabelType::Pointer ListLabelPointerType
void GenerateInputRequestedRegion(void) override
void SetInputVectorData(const VectorDataType *)
itk::Statistics::ListSample< SampleType > ListSampleType
DataNodeType::PolygonListPointerType PolygonListPointerType
const ImageType * GetInput() const
ListSampleType * GetValidationListSample()
VectorDataType::DataNodeType DataNodeType
DataNodeType::PolygonPointerType PolygonPointerType
RandomGeneratorType::Pointer m_RandomGenerator
itk::Statistics::ListSample< LabelType > ListLabelType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
ImageType::RegionType TransformPhysicalRegionToIndexRegion(const RemoteSensingRegionType &region, const ImageType *image)