OTB  10.0.0
Orfeo Toolbox
otbRandomForestsMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbRandomForestsMachineLearningModel_h
22 #define otbRandomForestsMachineLearningModel_h
23 
24 #include "otbRequiresOpenCVCheck.h"
25 
26 #include "itkLightObject.h"
27 #include "itkFixedArray.h"
29 #include "itkVariableSizeMatrix.h"
30 #include "otbCvRTreesWrapper.h"
31 
32 namespace otb
33 {
34 
    /** \class RandomForestsMachineLearningModel
     * \brief Random-forest machine-learning model exposed through the generic
     * otb::MachineLearningModel<TInputValue, TTargetValue> interface and backed
     * by a CvRTreesWrapper instance (OpenCV random trees) held in m_RFModel.
     *
     * Instances are created through the New() factory generated by itkNewMacro
     * and handled through Pointer / ConstPointer (itk::SmartPointer). Copy
     * assignment is deleted below.
     *
     * NOTE(review): this text is a Doxygen source-listing render. The embedded
     * line numbers skip (e.g. 38 -> 42, 48 -> 50, 157 -> 203), so several
     * declarations are not shown here. These include the Self / Superclass /
     * TargetSampleType typedefs, the constructor and destructor, and most of
     * the m_* data members behind the itkGetMacro/itkSetMacro accessors.
     * Consult the real header before editing.
     */
35 template <class TInputValue, class TTargetValue>
36 class ITK_EXPORT RandomForestsMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
37 {
38 public:
    // Standard ITK smart-pointer typedefs.
42  typedef itk::SmartPointer<Self> Pointer;
43  typedef itk::SmartPointer<const Self> ConstPointer;
44 
    // Sample, list-sample, target and confidence types re-exported from the
    // MachineLearningModel base class. The TargetSampleType typedef used by
    // DoPredict() sits on listing line 49, which this render omits.
45  typedef typename Superclass::InputValueType InputValueType;
46  typedef typename Superclass::InputSampleType InputSampleType;
47  typedef typename Superclass::InputListSampleType InputListSampleType;
48  typedef typename Superclass::TargetValueType TargetValueType;
50  typedef typename Superclass::TargetListSampleType TargetListSampleType;
51  typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
52  typedef typename Superclass::ProbaSampleType ProbaSampleType;
53  // Other: float matrix type returned by GetVariableImportance()
54  typedef itk::VariableSizeMatrix<float> VariableImportanceMatrixType;
55 
56 
57  // opencv typedef
    // NOTE(review): the OpenCV typedef that belongs here (listing line 58) is
    // not shown in this render.
59 
    // Object factory: generates the static New() method.
61  itkNewMacro(Self);
64 
    // The methods below override the MachineLearningModel interface.
    // Train() presumably fits m_RFModel on the samples and targets set on the
    // base class; its implementation is not shown here (TODO confirm).
66  void Train() override;
67 
    // Writes the model to `filename`; `name` is optional (defaults to "").
69  void Save(const std::string& filename, const std::string& name = "") override;
70 
    // Reads a model from `filename`; `name` is optional (defaults to "").
72  void Load(const std::string& filename, const std::string& name = "") override;
73 
76 
    // Reports whether the given file can be read (CanReadFile) or written
    // (CanWriteFile) by this model type.
78  bool CanReadFile(const std::string&) override;
79 
81  bool CanWriteFile(const std::string&) override;
83 
84  // Getters/setters of the random-trees (RT) training parameters (documentation originally taken from the OpenCV 2.4 Doxygen).
    // Each itkGetMacro/itkSetMacro pair reads/writes the matching m_<Name> member.
85  itkGetMacro(MaxDepth, int);
86  itkSetMacro(MaxDepth, int);
87 
88  itkGetMacro(MinSampleCount, int);
89  itkSetMacro(MinSampleCount, int);
90 
91  itkGetMacro(RegressionAccuracy, double);
92  itkSetMacro(RegressionAccuracy, double);
93 
94  itkGetMacro(ComputeSurrogateSplit, bool);
95  itkSetMacro(ComputeSurrogateSplit, bool);
96 
97  itkGetMacro(MaxNumberOfCategories, int);
98  itkSetMacro(MaxNumberOfCategories, int);
99 
    // Returns a copy of the priors vector stored in m_Priors.
    // NOTE(review): returns by value. A `const std::vector<float>&` return
    // would avoid the copy if callers never need their own vector.
100  std::vector<float> GetPriors() const
101  {
102  return m_Priors;
103  }
104 
    // Replaces m_Priors with a copy of `priors`.
105  void SetPriors(const std::vector<float>& priors)
106  {
107  m_Priors = priors;
108  }
109 
110  itkGetMacro(CalculateVariableImportance, bool);
111  itkSetMacro(CalculateVariableImportance, bool);
112 
113  itkGetMacro(MaxNumberOfVariables, int);
114  itkSetMacro(MaxNumberOfVariables, int);
115 
116  itkGetMacro(MaxNumberOfTrees, int);
117  itkSetMacro(MaxNumberOfTrees, int);
118 
119  itkGetMacro(ForestAccuracy, float);
120  itkSetMacro(ForestAccuracy, float);
121 
122  itkGetMacro(TerminationCriteria, int);
123  itkSetMacro(TerminationCriteria, int);
124 
125  itkGetMacro(ComputeMargin, bool);
126  itkSetMacro(ComputeMargin, bool);
127 
    // Returns the variable-importance matrix.
    // NOTE(review): presumably meaningful only after Train() with
    // CalculateVariableImportance enabled; confirm in the implementation.
    // Not declared const, so it cannot be called on a const model.
129  VariableImportanceMatrixType GetVariableImportance();
130 
    // Returns the training error as a float.
    // NOTE(review): its exact definition is in the implementation, which is
    // not shown here. Like GetVariableImportance(), it is not const.
131  float GetTrainError();
132 
133 protected:
    // NOTE(review): the constructor and the destructor
    // (`~RandomForestsMachineLearningModel() override = default;`, per the
    // listing tooltips) presumably sit on the omitted listing lines 134-140.
136 
139 
    // Predicts the target for a single input sample.
    // `quality` and `proba` are optional output pointers (default nullptr).
    // Presumably, when non-null, they receive a confidence value and
    // per-class probabilities; TODO confirm in the implementation.
141  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;
142 
    // Writes a description of the object state to `os` (ITK PrintSelf override).
144  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
145 
    // NOTE(review): dead, commented-out members. The list samples are
    // presumably held by the base class; consider deleting these lines.
146  /* /\** Input list sample *\/ */
147  /* typename InputListSampleType::Pointer m_InputListSample; */
148 
149  /* /\** Target list sample *\/ */
150  /* typename TargetListSampleType::Pointer m_TargetListSample; */
151 
152 private:
    // Non-copyable: copy assignment is deleted here. The deleted copy
    // constructor (listing line 153) is omitted by this render.
154  void operator=(const Self&) = delete;
155 
    // The underlying OpenCV random-trees model wrapper.
156  cv::Ptr<CvRTreesWrapper> m_RFModel;
157 
    // NOTE(review): the m_* members backing the itkGet/SetMacro accessors
    // (m_MaxDepth, m_MinSampleCount, ...) are on listing lines this render
    // omits.
162 
166 
172 
188 
    // Priors vector exposed through GetPriors()/SetPriors().
203  std::vector<float> m_Priors;
204 
208 
213 
220 
223 
226 
231 };
232 } // end namespace otb
234 
235 #ifndef OTB_MANUAL_INSTANTIATION
237 #endif
238 
239 #endif
Wrapper for OpenCV Random Trees.
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests,...
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
itk::VariableLengthVector< double > ProbaSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
MachineLearningModel< TInputValue, TTargetValue > Superclass
void operator=(const Self &)=delete
itk::VariableSizeMatrix< float > VariableImportanceMatrixType
RandomForestsMachineLearningModel(const Self &)=delete
void SetPriors(const std::vector< float > &priors)
~RandomForestsMachineLearningModel() override=default
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.