OTB  10.0.0
Orfeo Toolbox
otbLearningApplicationBase.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbLearningApplicationBase_h
22 #define otbLearningApplicationBase_h
23 
24 #include "otbConfigure.h"
25 
26 #include "otbWrapperApplication.h"
27 
28 
29 // ListSample
30 #include "itkListSample.h"
31 #include "itkVariableLengthVector.h"
32 
33 // Estimator
35 #include <string>
36 
37 namespace otb
38 {
39 namespace Wrapper
40 {
41 
// NOTE(review): this is a Doxygen source-browser extraction, not compilable C++.
// Each line carries the original header's line number as a prefix, and several
// original lines are absent from this listing: 42-73 (class documentation), 75
// (the class head, presumably deriving from otb::Application per the
// itkTypeMacro below), 78-80 (presumably the Self/Superclass typedefs; Self is
// used but not declared here), 84, 90-91, 94-96, 98-99, 101-103, 108, 110-111,
// 114-119, 122, 126-127, 130, 133, 136-138, 141-144, 147-148, 151 and 189.
// ListSampleType and TargetListSampleType are used below but declared on the
// missing lines. Regenerate from the real header before editing code here.
74 template <class TInputValue, class TOutputValue>
76 {
77 public:
// Standard ITK smart-pointer typedefs for this class.
81  typedef itk::SmartPointer<Self> Pointer;
82  typedef itk::SmartPointer<const Self> ConstPointer;
83 
// Run-time type information; the second argument names otb::Application as
// the parent class.
85  itkTypeMacro(LearningApplicationBase, otb::Application);
86 
// Input-sample and output (target) value types, taken directly from the
// template parameters.
87  typedef TInputValue InputValueType;
88  typedef TOutputValue OutputValueType;
89 
92 
93  // Machine Learning models
97 
100 
104 
// Read-only accessors, generated by itkGetConstReferenceMacro, that return
// const references to m_SupervisedClassifier / m_UnsupervisedClassifier.
105  itkGetConstReferenceMacro(SupervisedClassifier, std::vector<std::string>);
106  itkGetConstReferenceMacro(UnsupervisedClassifier, std::vector<std::string>);
107 
// NOTE(review): body of a type definition spanning original lines 108-112;
// its header (line 108) and members (lines 110-111) are missing here.
109  {
112  };
113 
120 
121 protected:
123 
124  ~LearningApplicationBase() override;
125 
// Trains a model from trainingListSample with labels from
// trainingLabeledListSample, using modelPath as the model file path.
// Presumably this dispatches to one of the private Train* helpers below
// according to the selected classifier — confirm in the implementation.
128  void Train(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
129 
// Classifies validationListSample using the model at modelPath and returns the
// predicted targets as a TargetListSampleType. Presumably this loads the model
// from modelPath — confirm in the implementation.
131  typename TargetListSampleType::Pointer Classify(typename ListSampleType::Pointer validationListSample, std::string modelPath);
132 
// Initialisation hook overriding the base-class DoInit().
134  void DoInit() override;
135 
139 
140 private:
// Names of the available supervised models (exposed via
// GetSupervisedClassifier()); presumably populated during DoInit() — verify.
145  std::vector<std::string> m_SupervisedClassifier;
146 
// Names of the available unsupervised models (exposed via
// GetUnsupervisedClassifier()); presumably populated during DoInit() — verify.
149  std::vector<std::string> m_UnsupervisedClassifier;
150 
// Per-backend helpers: Init*Params() and Train*(). Each group is declared only
// when the matching OTB_USE_* macro is defined, so the set of usable models
// depends on the build configuration.
152 #ifdef OTB_USE_LIBSVM
153  void InitLibSVMParams();
154 
155  void TrainLibSVM(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample,
156  std::string modelPath);
157 #endif
158 
// OpenCV-backed models: Boost, SVM, decision tree, neural network,
// normal Bayes, random forests and KNN.
159 #ifdef OTB_USE_OPENCV
160  void InitBoostParams();
161  void InitSVMParams();
162  void InitDecisionTreeParams();
163  void InitNeuralNetworkParams();
164  void InitNormalBayesParams();
165  void InitRandomForestsParams();
166  void InitKNNParams();
167 
168  void TrainBoost(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
169  void TrainSVM(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
170  void TrainDecisionTree(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample,
171  std::string modelPath);
172  void TrainNeuralNetwork(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample,
173  std::string modelPath);
174  void TrainNormalBayes(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample,
175  std::string modelPath);
176  void TrainRandomForests(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample,
177  std::string modelPath);
178  void TrainKNN(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
179 #endif
180 
// Shark-backed models: random forests and k-means. The k-means helper shares
// the Train* signature, including the label list parameter.
181 #ifdef OTB_USE_SHARK
182  void InitSharkRandomForestsParams();
183  void TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample,
184  std::string modelPath);
185  void InitSharkKMeansParams();
186  void TrainSharkKMeans(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample,
187  std::string modelPath);
188 #endif
190 };
191 }
192 }
193 
194 #ifndef OTB_MANUAL_INSTANTIATION
196 #ifdef OTB_USE_OPENCV
197 #include "otbTrainBoost.hxx"
198 #include "otbTrainDecisionTree.hxx"
199 #include "otbTrainKNN.hxx"
200 #include "otbTrainNeuralNetwork.hxx"
201 #include "otbTrainNormalBayes.hxx"
202 #include "otbTrainRandomForests.hxx"
203 #include "otbTrainSVM.hxx"
204 #endif
205 #ifdef OTB_USE_LIBSVM
206 #include "otbTrainLibSVM.hxx"
207 #endif
208 #ifdef OTB_USE_SHARK
210 #include "otbTrainSharkKMeans.hxx"
211 #endif
212 #endif
213 
214 #endif
Creation of object instance using object factory.
MachineLearningModelType::Pointer MachineLearningModelTypePointer
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests,...
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
itk::Statistics::ListSample< InputSampleType > InputListSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Creation of an "otb" vector image which contains metadata.
Superclass::PixelType PixelType
This class represents an application (TODO: complete this description).
LearningApplicationBase is the base class for applications that use machine learning models.
void Train(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
otb::VectorImage< InputValueType > SampleImageType
ModelType::TargetListSampleType TargetListSampleType
ModelFactoryType::MachineLearningModelType ModelType
ModelType::InputListSampleType ListSampleType
TargetListSampleType::Pointer Classify(typename ListSampleType::Pointer validationListSample, std::string modelPath)
ModelFactoryType::MachineLearningModelTypePointer ModelPointerType
otb::MachineLearningModelFactory< InputValueType, OutputValueType > ModelFactoryType
itk::SmartPointer< const Self > ConstPointer
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.