OTB  10.0.0
Orfeo Toolbox
otbMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbMachineLearningModel_h
22 #define otbMachineLearningModel_h
23 
24 #include "itkObject.h"
25 #include "itkListSample.h"
27 
// NOTE(review): this text is a Doxygen source-browser extraction, not a
// compilable header. The leading number on each line is the original header
// line number, and lines are missing wherever the numbering jumps (e.g.
// 31-68, 73-75, 81-84, 199-200). Judging from the cross-reference residue at
// the end of the page, the dropped lines include the `Self` typedef (used
// below but never declared here), the MLMSampleTraits/MLMTargetTraits-based
// typedefs (InputValueType, InputSampleType, TargetValueType,
// TargetSampleType, ConfidenceValueType, ConfidenceSampleType) and the
// itkGetObjectMacro accessors. Check the real otbMachineLearningModel.h
// before relying on this listing.
28 namespace otb
29 {
30 
// MachineLearningModel: abstract base class for the OTB machine-learning
// models. Train(), Save(), Load(), CanReadFile(), CanWriteFile() and
// DoPredict() are pure virtual, so only concrete subclasses can be created.
//   TInputValue      - input value type; InputValueType / InputSampleType are
//                      derived from it through MLMSampleTraits (per the page
//                      residue).
//   TTargetValue     - target value type; TargetValueType / TargetSampleType
//                      are derived from it through MLMTargetTraits.
//   TConfidenceValue - confidence value type (defaults to double), mapped
//                      through MLMTargetTraits as well.
// Copy construction and copy assignment are deleted (end of class).
69 template <class TInputValue, class TTargetValue, class TConfidenceValue = double>
70 class ITK_EXPORT MachineLearningModel : public itk::Object
71 {
72 public:
  // Standard ITK typedefs (the `Self` typedef sits on a line missing here).
76  typedef itk::Object Superclass;
77  typedef itk::SmartPointer<Self> Pointer;
78  typedef itk::SmartPointer<const Self> ConstPointer;
80 
  // ITK list-sample container holding input samples.
85  typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
87 
  // ITK list-sample container holding target samples.
92  typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
94 
  // ITK list-sample container holding confidence samples.
98  typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType;
99 
100 
  // Probability output for one sample: a variable-length vector of doubles
  // (presumably one entry per class -- TODO confirm), plus a list of those.
101  typedef itk::VariableLengthVector<double> ProbaSampleType;
102  typedef itk::Statistics::ListSample<ProbaSampleType> ProbaListSampleType;
105 
  // ITK run-time type information (GetNameOfClass() and related methods).
107  itkTypeMacro(MachineLearningModel, itk::Object);
109 
  // Train the model. Pure virtual: every concrete model implements it.
  // NOTE(review): presumably trains from m_InputListSample and
  // m_TargetListSample -- confirm in the subclasses.
111  virtual void Train() = 0;
112 
  // Predict a single sample.
  //   input   - the sample to predict.
  //   quality - optional out-pointer for a confidence value; nullptr (the
  //             default) means no confidence is requested.
  //   proba   - optional out-pointer for probability values; nullptr (the
  //             default) means none are requested.
  // Returns the predicted target. Non-virtual and defined out of line (not
  // in this listing); presumably it forwards to the private DoPredict() hook.
119  TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const;
120 
  // Model dimension stored in m_Dimension; the ITK macros generate
  // SetDimension() and GetDimension().
123  itkSetMacro(Dimension, unsigned int);
124  itkGetMacro(Dimension, unsigned int);
126 
127 
  // Predict a batch of samples.
  //   input   - list of samples to predict.
  //   quality - optional list to receive confidence values, or nullptr.
  //   proba   - optional list to receive probability values, or nullptr.
  // Returns a smart pointer to a list of predicted targets. Defined out of
  // line; presumably built on DoPredictBatch() -- confirm in the .hxx.
136  typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType* input, ConfidenceListSampleType* quality = nullptr,
137  ProbaListSampleType* proba = nullptr) const;
138 
141 
  // Save the model to `filename`; `name` is an optional extra identifier
  // (defaults to ""). Pure virtual.
143  virtual void Save(const std::string& filename, const std::string& name = "") = 0;
144 
  // Load the model from `filename`; `name` as for Save(). Pure virtual.
146  virtual void Load(const std::string& filename, const std::string& name = "") = 0;
148 
151 
  // Report whether this model type can read / write the given model file
  // (meaning inferred from the names; each subclass decides). Pure virtual.
153  virtual bool CanReadFile(const std::string&) = 0;
154 
156  virtual bool CanWriteFile(const std::string&) = 0;
158 
  // Returns the m_ConfidenceIndex flag: whether the model can produce a
  // confidence value. The member is declared on a line missing here.
160  bool HasConfidenceIndex() const
161  {
162  return m_ConfidenceIndex;
163  }
164 
  // Returns the m_ProbaIndex flag: whether the model can produce probability
  // values. The member is declared on a line missing here.
166  bool HasProbaIndex() const
167  {
168  return m_ProbaIndex;
169  }
170 
  // Input list sample accessors: SetInputListSample() and a const getter.
  // A non-const itkGetObjectMacro getter appears only in the page residue
  // (its line is missing here).
173  itkSetObjectMacro(InputListSample, InputListSampleType);
175  itkGetConstObjectMacro(InputListSample, InputListSampleType);
177 
178 
181 
  // Target list sample setter; the matching itkGetObjectMacro getter
  // appears only in the page residue.
183  itkSetObjectMacro(TargetListSample, TargetListSampleType);
184 
188 
190 
  // Regression mode: GetRegressionMode() is macro-generated, while
  // SetRegressionMode() is defined out of line (not visible, so any check it
  // performs is unknown from here). m_RegressionMode is declared on a line
  // missing from this listing.
193  itkGetMacro(RegressionMode, bool);
194  void SetRegressionMode(bool flag);
196 
197 
198 protected:
  // NOTE(review): lines 199-200 are missing; presumably the protected
  // constructor declaration -- confirm against the real header.
201 
203  ~MachineLearningModel() override = default;
204 
  // Override of itk::Object::PrintSelf, used by ITK's Print().
206  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
207 
  // Input samples set through SetInputListSample().
209  typename InputListSampleType::Pointer m_InputListSample;
210 
  // Validation samples; no accessor for them is visible in this listing.
212  typename InputListSampleType::Pointer m_ValidationListSample;
213 
  // Target samples set through SetTargetListSample() (presumably paired
  // with m_InputListSample -- confirm).
215  typename TargetListSampleType::Pointer m_TargetListSample;
216 
  // Confidence samples; an itkGetObjectMacro getter for them appears in the
  // page residue.
217  typename ConfidenceListSampleType::Pointer m_ConfidenceListSample;
218 
  // NOTE(review): original lines 219-236 are mostly missing; they presumably
  // declare the flags referenced above (m_RegressionMode, m_ConfidenceIndex,
  // m_ProbaIndex) and possibly others -- confirm.
221 
226 
229 
232 
235 
  // Dimension exposed through SetDimension()/GetDimension(). No initializer
  // here; presumably set in the constructor -- confirm.
237  unsigned int m_Dimension;
238 
239 private:
  // Batch-prediction hook. Virtual but not pure, so a default implementation
  // exists out of line (not visible here) and subclasses may override it.
  // Parameter meanings below are inferred from their names:
  //   input      - the full input list.
  //   startIndex - index of the first sample to predict.
  //   size       - number of samples to predict.
  //   target     - list that receives the predicted targets.
  //   quality, proba - optional outputs, nullptr when not requested.
  // NOTE(review): default arguments of a virtual function come from the
  // static type used at the call site, not from the overrider; overrides
  // should repeat identical defaults.
255  virtual void DoPredictBatch(const InputListSampleType* input, const unsigned int& startIndex, const unsigned int& size, TargetListSampleType* target,
256  ConfidenceListSampleType* quality = nullptr, ProbaListSampleType* proba = nullptr) const;
257 
  // Single-sample prediction hook that every concrete model must implement;
  // parameters mirror those of Predict().
264  virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const = 0;
265 
  // Non-copyable.
266  MachineLearningModel(const Self&) = delete;
267  void operator=(const Self&) = delete;
268 };
269 } // end namespace otb
270 
271 #ifndef OTB_MANUAL_INSTANTIATION
273 #endif
274 
275 #endif
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests, ...).
MLMSampleTraits< TInputValue >::ValueType InputValueType
virtual void Save(const std::string &filename, const std::string &name="")=0
virtual void Train()=0
itk::Statistics::ListSample< ConfidenceSampleType > ConfidenceListSampleType
itk::Statistics::ListSample< ProbaSampleType > ProbaListSampleType
TargetListSampleType::Pointer m_TargetListSample
MLMTargetTraits< TConfidenceValue >::SampleType ConfidenceSampleType
~MachineLearningModel() override=default
ConfidenceListSampleType::Pointer m_ConfidenceListSample
InputListSampleType::Pointer m_InputListSample
itkGetObjectMacro(ConfidenceListSample, ConfidenceListSampleType)
itk::SmartPointer< Self > Pointer
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
itk::SmartPointer< const Self > ConstPointer
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
itkGetObjectMacro(InputListSample, InputListSampleType)
virtual bool CanReadFile(const std::string &)=0
itk::VariableLengthVector< double > ProbaSampleType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
MachineLearningModel(const Self &)=delete
itk::Statistics::ListSample< InputSampleType > InputListSampleType
itkGetObjectMacro(TargetListSample, TargetListSampleType)
void operator=(const Self &)=delete
MLMSampleTraits< TInputValue >::SampleType InputSampleType
virtual TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const =0
virtual void Load(const std::string &filename, const std::string &name="")=0
virtual bool CanWriteFile(const std::string &)=0
InputListSampleType::Pointer m_ValidationListSample
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.