OTB  10.0.0
Orfeo Toolbox
otbSharkRandomForestsMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSharkRandomForestsMachineLearningModel_h
22 #define otbSharkRandomForestsMachineLearningModel_h
23 
24 #include "itkLightObject.h"
26 
27 // Quiet a deprecation warning
28 #define BOOST_BIND_GLOBAL_PLACEHOLDERS
29 
30 #if defined(__GNUC__) || defined(__clang__)
31 #pragma GCC diagnostic push
32 
33 #if (defined (__GNUC__) && (__GNUC__ >= 9)) || (defined (__clang__) && (__clang_major__ >= 10))
34 #pragma GCC diagnostic ignored "-Wdeprecated-copy"
35 #endif
36 #pragma GCC diagnostic ignored "-Wshadow"
37 #pragma GCC diagnostic ignored "-Wunused-parameter"
38 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
39 #pragma GCC diagnostic ignored "-Wignored-qualifiers"
40 #pragma GCC diagnostic ignored "-Wsign-compare"
41 #pragma GCC diagnostic ignored "-Wcast-align"
42 #pragma GCC diagnostic ignored "-Wunknown-pragmas"
43 #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
44 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
45 #if defined(__clang__)
46 #pragma clang diagnostic ignored "-Wheader-guard"
47 #pragma clang diagnostic ignored "-Wexpansion-to-defined"
48 #else
49 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
50 #endif
51 #endif
52 #include <shark/Models/Classifier.h>
53 #include "otb_shark.h"
54 #include "shark/Algorithms/Trainers/RFTrainer.h"
55 #if defined(__GNUC__) || defined(__clang__)
56 #pragma GCC diagnostic pop
57 #endif
58 
59 
74 namespace otb
75 {
76 template <class TInputValue, class TTargetValue>
77 class ITK_EXPORT SharkRandomForestsMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
78 {
79 public:
83  typedef itk::SmartPointer<Self> Pointer;
84  typedef itk::SmartPointer<const Self> ConstPointer;
85 
86  typedef typename Superclass::InputValueType InputValueType;
87  typedef typename Superclass::InputSampleType InputSampleType;
88  typedef typename Superclass::InputListSampleType InputListSampleType;
89  typedef typename Superclass::TargetValueType TargetValueType;
91  typedef typename Superclass::TargetListSampleType TargetListSampleType;
93  typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
94  typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
95  typedef typename Superclass::ProbaSampleType ProbaSampleType;
96  typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
98  itkNewMacro(Self);
101 
103  virtual void Train() override;
104 
106  virtual void Save(const std::string& filename, const std::string& name = "") override;
107 
109  virtual void Load(const std::string& filename, const std::string& name = "") override;
110 
113 
115  virtual bool CanReadFile(const std::string&) override;
116 
118  virtual bool CanWriteFile(const std::string&) override;
120 
122  itkGetMacro(NumberOfTrees, unsigned int);
123 
125  itkSetMacro(NumberOfTrees, unsigned int);
126 
128  itkGetMacro(MTry, unsigned int);
129 
131  itkSetMacro(MTry, unsigned int);
132 
136  itkGetMacro(NodeSize, unsigned int);
137 
141  itkSetMacro(NodeSize, unsigned int);
142 
146  itkGetMacro(OobRatio, float);
147 
151  itkSetMacro(OobRatio, float);
152 
154  itkGetMacro(ComputeMargin, bool);
155 
157  itkSetMacro(ComputeMargin, bool);
158 
160  itkGetMacro(NormalizeClassLabels, bool);
161  itkSetMacro(NormalizeClassLabels, bool);
163 
164 protected:
167 
170 
172  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;
173 
174  void DoPredictBatch(const InputListSampleType*, const unsigned int& startIndex, const unsigned int& size, TargetListSampleType*,
175  ConfidenceListSampleType* = nullptr, ProbaListSampleType* = nullptr) const override;
176 
178  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
179 
180 private:
182  void operator=(const Self&) = delete;
183 
184  shark::RFClassifier<unsigned int> m_RFModel;
185  shark::RFTrainer<unsigned int> m_RFTrainer;
186  std::vector<unsigned int> m_ClassDictionary;
188 
189  unsigned int m_NumberOfTrees;
190  unsigned int m_MTry;
191  unsigned int m_NodeSize;
192  float m_OobRatio;
194 
196  ConfidenceValueType ComputeConfidence(shark::RealVector& probas, bool computeMargin) const;
197 };
198 } // end namespace otb
199 
200 #ifndef OTB_MANUAL_INSTANTIATION
202 #endif
203 
204 #endif
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests, ...).
itk::Statistics::ListSample< ConfidenceSampleType > ConfidenceListSampleType
itk::Statistics::ListSample< ProbaSampleType > ProbaListSampleType
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
itk::VariableLengthVector< double > ProbaSampleType
itk::Statistics::ListSample< InputSampleType > InputListSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
MachineLearningModel< TInputValue, TTargetValue > Superclass
~SharkRandomForestsMachineLearningModel() override=default
SharkRandomForestsMachineLearningModel(const Self &)=delete
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.