OTB  10.0.0
Orfeo Toolbox
otbTrainImagesBase.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbTrainImagesBase_h
21 #define otbTrainImagesBase_h
22 
26 
31 #include <string>
32 
33 namespace otb
34 {
35 namespace Wrapper
36 {
37 
48 {
49 public:
53  typedef itk::SmartPointer<Self> Pointer;
54  typedef itk::SmartPointer<const Self> ConstPointer;
55 
57  itkTypeMacro(TrainImagesBase, Superclass);
58 
61 
63 
64 protected:
65  typedef enum { CLASS, GEOMETRIC } SamplingStrategy;
66  struct SamplingRates;
67  class TrainFileNamesHandler;
68 
72  void InitIO();
73 
77  void InitSampling();
78 
81  void InitClassification();
84 
91  void ComputePolygonStatistics(FloatVectorImageListType* imageList, const std::vector<std::string>& vectorFileNames,
92  const std::vector<std::string>& statisticsFileNames);
93 
99  SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation);
100 
101 
109  void ComputeSamplingRate(const std::vector<std::string>& statisticsFileNames, const std::string& ratesFileName, long maximum);
110 
117  void TrainModel(FloatVectorImageListType* imageList, const std::vector<std::string>& sampleTrainFileNames,
118  const std::vector<std::string>& sampleValidationFileNames);
119 
129  void SelectAndExtractSamples(FloatVectorImageType* image, std::string vectorFileName, std::string sampleFileName, std::string statisticsFileName,
130  std::string ratesFileName, SamplingStrategy strategy, std::string selectedField = "");
131 
140  void SelectAndExtractTrainSamples(const TrainFileNamesHandler& fileNames, FloatVectorImageListType* imageList, std::vector<std::string> vectorFileNames,
141  SamplingStrategy strategy, std::string selectedFieldName = "");
142 
143 
153  void SelectAndExtractValidationSamples(const TrainFileNamesHandler& fileNames, FloatVectorImageListType* imageList,
154  const std::vector<std::string>& validationVectorFileList = std::vector<std::string>());
155 
162  void SplitTrainingToValidationSamples(const TrainFileNamesHandler& fileNames, FloatVectorImageListType* imageList);
163 
164 private:
173  void SplitTrainingAndValidationSamples(FloatVectorImageType* image, std::string sampleFileName, std::string sampleTrainFileName,
174  std::string sampleValidFileName, std::string ratesTrainFileName);
175 
176 
177 protected:
179  {
180  long int fmt;
181  long int fmv;
182  };
183 
191  {
192  public:
193  void CreateTemporaryFileNames(std::string outModel, size_t nbInputs, bool dedicatedValidation)
194  {
195 
196  if (dedicatedValidation)
197  {
198  rateTrainOut = outModel + "_ratesTrain.csv";
199  }
200  else
201  {
202  rateTrainOut = outModel + "_rates.csv";
203  }
204 
205  rateValidOut = outModel + "_ratesValid.csv";
206  for (unsigned int i = 0; i < nbInputs; i++)
207  {
208  std::ostringstream oss;
209  oss << i + 1;
210  std::string strIndex(oss.str());
211  if (dedicatedValidation)
212  {
213  polyStatTrainOutputs.push_back(outModel + "_statsTrain_" + strIndex + ".xml");
214  polyStatValidOutputs.push_back(outModel + "_statsValid_" + strIndex + ".xml");
215  ratesTrainOutputs.push_back(outModel + "_ratesTrain_" + strIndex + ".csv");
216  ratesValidOutputs.push_back(outModel + "_ratesValid_" + strIndex + ".csv");
217  sampleOutputs.push_back(outModel + "_samplesTrain_" + strIndex + ".shp");
218  }
219  else
220  {
221  polyStatTrainOutputs.push_back(outModel + "_stats_" + strIndex + ".xml");
222  ratesTrainOutputs.push_back(outModel + "_rates_" + strIndex + ".csv");
223  sampleOutputs.push_back(outModel + "_samples_" + strIndex + ".shp");
224  }
225  sampleTrainOutputs.push_back(outModel + "_samplesTrain_" + strIndex + ".shp");
226  sampleValidOutputs.push_back(outModel + "_samplesValid_" + strIndex + ".shp");
227  }
228  }
229 
230  void clear()
231  {
232  for (unsigned int i = 0; i < polyStatTrainOutputs.size(); i++)
234  for (unsigned int i = 0; i < polyStatValidOutputs.size(); i++)
236  for (unsigned int i = 0; i < ratesTrainOutputs.size(); i++)
238  for (unsigned int i = 0; i < ratesValidOutputs.size(); i++)
240  for (unsigned int i = 0; i < sampleOutputs.size(); i++)
242  for (unsigned int i = 0; i < sampleTrainOutputs.size(); i++)
244  for (unsigned int i = 0; i < sampleValidOutputs.size(); i++)
246  for (unsigned int i = 0; i < tmpVectorFileList.size(); i++)
248  }
249 
250  public:
251  std::vector<std::string> polyStatTrainOutputs;
252  std::vector<std::string> polyStatValidOutputs;
253  std::vector<std::string> ratesTrainOutputs;
254  std::vector<std::string> ratesValidOutputs;
255  std::vector<std::string> sampleOutputs;
256  std::vector<std::string> sampleTrainOutputs;
257  std::vector<std::string> sampleValidOutputs;
258  std::vector<std::string> tmpVectorFileList;
259  std::string rateValidOut;
260  std::string rateTrainOut;
261 
262  private:
263  bool RemoveFile(std::string& filePath)
264  {
265  itksys::Status res;
266  if (itksys::SystemTools::FileExists(filePath))
267  {
268  size_t posExt = filePath.rfind('.');
269  if (posExt != std::string::npos && filePath.compare(posExt, std::string::npos, ".shp") == 0)
270  {
271  std::string shxPath = filePath.substr(0, posExt) + std::string(".shx");
272  std::string dbfPath = filePath.substr(0, posExt) + std::string(".dbf");
273  std::string prjPath = filePath.substr(0, posExt) + std::string(".prj");
274  RemoveFile(shxPath);
275  RemoveFile(dbfPath);
276  RemoveFile(prjPath);
277  }
278  res = itksys::SystemTools::RemoveFile(filePath);
279  }
280  bool status = res.GetKind() == itksys::Status::Kind::Success;
281  return status;
282  }
283  };
284 };
285 
286 } // end namespace Wrapper
287 } // end namespace otb
288 
289 #ifndef OTB_MANUAL_INSTANTIATION
290 #include "otbTrainImagesBase.hxx"
291 #endif
292 
293 #endif // otbTrainImagesBase_h
Extracts sample position from an image using a persistent filter.
This class is a generic all-purpose wrapping around an std::vector<itk::SmartPointer<ObjectType> >.
Definition: otbObjectList.h:41
std::map< std::string, TripletType > MapRateType
Creation of an "otb" vector image which contains metadata.
This class is a base class for composite applications.
void CreateTemporaryFileNames(std::string outModel, size_t nbInputs, bool dedicatedValidation)
Base class for the TrainImagesClassifier.
otb::SamplingRateCalculator::MapRateType MapRateType
void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName, std::string sampleTrainFileName, std::string sampleValidFileName, std::string ratesTrainFileName)
void SplitTrainingToValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList)
void SelectAndExtractSamples(FloatVectorImageType *image, std::string vectorFileName, std::string sampleFileName, std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy, std::string selectedField="")
otb::OGRDataToSamplePositionFilter< FloatVectorImageType, UInt8ImageType, otb::PeriodicSampler > PeriodicSamplerType
itk::SmartPointer< Self > Pointer
void ComputePolygonStatistics(FloatVectorImageListType *imageList, const std::vector< std::string > &vectorFileNames, const std::vector< std::string > &statisticsFileNames)
void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, const std::vector< std::string > &validationVectorFileList=std::vector< std::string >())
void SelectAndExtractTrainSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, std::vector< std::string > vectorFileNames, SamplingStrategy strategy, std::string selectedFieldName="")
void ComputeSamplingRate(const std::vector< std::string > &statisticsFileNames, const std::string &ratesFileName, long maximum)
itk::SmartPointer< const Self > ConstPointer
void TrainModel(FloatVectorImageListType *imageList, const std::vector< std::string > &sampleTrainFileNames, const std::vector< std::string > &sampleValidationFileNames)
SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation)
CompositeApplication Superclass
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.