OTB  10.0.0
Orfeo Toolbox
otbMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbMachineLearningModel_hxx
22 #define otbMachineLearningModel_hxx
23 
24 #ifdef _OPENMP
25 #include <omp.h>
26 #endif
27 
29 #include "itkMultiThreaderBase.h"
30 
31 namespace otb
32 {
33 
// Constructor of the class template (the member-initializer list below makes
// this a constructor; its declarator, listing line 35, is missing from this
// listing). Every boolean-initialised member starts as `false` and
// m_Dimension starts at 0; the body itself is empty.
// NOTE(review): m_ConfidenceIndex and m_ProbaIndex are initialised with
// `false` despite the "Index" suffix - presumably boolean capability flags;
// confirm against the class declaration.
34 template <class TInputValue, class TOutputValue, class TConfidenceValue>
36  : m_RegressionMode(false),
37  m_IsRegressionSupported(false),
38  m_ConfidenceIndex(false),
39  m_ProbaIndex(false),
40  m_IsDoPredictBatchMultiThreaded(false),
41  m_Dimension(0)
42 {
43 }
44 
// Setter for m_RegressionMode. The declarator (listing line 46) is missing
// from this listing; presumably `SetRegressionMode(bool flag)` - confirm
// against the class declaration.
//
// Raises an exception through itkGenericExceptionMacro when `flag` is true
// while m_IsRegressionSupported is false. Otherwise stores `flag` and calls
// Modified(), but only when the stored value actually changes.
45 template <class TInputValue, class TOutputValue, class TConfidenceValue>
47 {
// Reject the request before touching any state.
48  if (flag && !m_IsRegressionSupported)
49  {
50  itkGenericExceptionMacro(<< "Regression mode not implemented.");
51  }
// Skip the Modified() notification when nothing changes.
52  if (m_RegressionMode != flag)
53  {
54  m_RegressionMode = flag;
55  this->Modified();
56  }
57 }
58 
// Predict: single-sample prediction entry point (the first lines of the
// declarator, listing lines 60-61, are missing from this listing; only the
// trailing `proba` parameter is visible).
//
// Forwards `input`, `quality` and `proba` unchanged to DoPredict() and
// returns whatever DoPredict() returns. No checking is done at this level.
59 template <class TInputValue, class TOutputValue, class TConfidenceValue>
62  ProbaSampleType* proba) const
63 {
64  // Call protected specialization entry point
65  return this->DoPredict(input, quality, proba);
66 }
67 
68 
69 template <class TInputValue, class TOutputValue, class TConfidenceValue>
72  ProbaListSampleType* proba) const
73 {
74  // std::cout << "Enter batch predict" << std::endl;
75  typename TargetListSampleType::Pointer targets = TargetListSampleType::New();
76  targets->Resize(input->Size());
77 
78  if (quality != nullptr)
79  {
80  quality->Clear();
81  quality->Resize(input->Size());
82  }
83  if (proba != ITK_NULLPTR)
84  {
85  proba->Clear();
86  proba->Resize(input->Size());
87  }
88  if (m_IsDoPredictBatchMultiThreaded)
89  {
90  // Simply calls DoPredictBatch
91  this->DoPredictBatch(input, 0, input->Size(), targets, quality, proba);
92  return targets;
93  }
94  else
95  {
96 #ifdef _OPENMP
97  // OpenMP threading here
98  unsigned int nb_threads(0), threadId(0), nb_batches(0);
99 
100 #pragma omp parallel shared(nb_threads, nb_batches) private(threadId)
101  {
102  // Get number of threads configured with ITK
103  omp_set_num_threads(itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads());
104  nb_threads = omp_get_num_threads();
105  threadId = omp_get_thread_num();
106  nb_batches = std::min(nb_threads,(unsigned int)input->Size());
107  // Ensure that we do not spawn unnecessary threads
108  if(threadId<nb_batches)
109  {
110  unsigned int batch_size = ((unsigned int)input->Size() / nb_batches);
111  unsigned int batch_start = threadId * batch_size;
112  if (threadId == nb_threads - 1)
113  {
114  batch_size += input->Size() % nb_batches;
115  }
116 
117  this->DoPredictBatch(input, batch_start, batch_size, targets, quality, proba);
118  }
119  }
120 #else
121  this->DoPredictBatch(input, 0, input->Size(), targets, quality, proba);
122 #endif
123  return targets;
124  }
125 }
126 
127 
128 template <class TInputValue, class TOutputValue, class TConfidenceValue>
130  const unsigned int& size, TargetListSampleType* targets,
131  ConfidenceListSampleType* quality, ProbaListSampleType* proba) const
132 {
133  assert(input != nullptr);
134  assert(targets != nullptr);
135 
136  assert(input->Size() == targets->Size() && "Input sample list and target label list do not have the same size.");
137  assert(((quality == nullptr) || (quality->Size() == input->Size())) &&
138  "Quality samples list is not null and does not have the same size as input samples list");
139  assert(((proba == nullptr) || (input->Size() == proba->Size())) && "Proba sample list and target label list do not have the same size.");
140 
141  if (startIndex + size > input->Size())
142  {
143  itkExceptionMacro(<< "requested range [" << startIndex << ", " << startIndex + size << "[ partially outside input sample list range.[0," << input->Size()
144  << "[");
145  }
146 
147  if (proba != nullptr)
148  {
149  for (unsigned int id = startIndex; id < startIndex + size; ++id)
150  {
151  ProbaSampleType prob;
152  ConfidenceValueType confidence = 0;
153  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id), &confidence, &prob);
154  quality->SetMeasurementVector(id, confidence);
155  proba->SetMeasurementVector(id, prob);
156  targets->SetMeasurementVector(id, target);
157  }
158  }
159  else if (quality != ITK_NULLPTR)
160  {
161  for (unsigned int id = startIndex; id < startIndex + size; ++id)
162  {
163  ConfidenceValueType confidence = 0;
164  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id), &confidence);
165  quality->SetMeasurementVector(id, confidence);
166  targets->SetMeasurementVector(id, target);
167  }
168  }
169  else
170  {
171  for (unsigned int id = startIndex; id < startIndex + size; ++id)
172  {
173  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id));
174  targets->SetMeasurementVector(id, target);
175  }
176  }
177 }
178 
// PrintSelf: the declarator (listing line 180) is missing from this listing.
// Delegates to Superclass::PrintSelf() with the same stream `os` and
// `indent`; this class prints nothing of its own here.
179 template <class TInputValue, class TOutputValue, class TConfidenceValue>
181 {
182  // Call superclass implementation
183  Superclass::PrintSelf(os, indent);
184 }
185 }
186 
187 #endif
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests, ...).
itk::Statistics::ListSample< ConfidenceSampleType > ConfidenceListSampleType
itk::Statistics::ListSample< ProbaSampleType > ProbaListSampleType
void PrintSelf(std::ostream &os, itk::Indent indent) const override
TargetListSampleType::Pointer PredictBatch(const InputListSampleType *input, ConfidenceListSampleType *quality=nullptr, ProbaListSampleType *proba=nullptr) const
TargetSampleType Predict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
itk::VariableLengthVector< double > ProbaSampleType
virtual void DoPredictBatch(const InputListSampleType *input, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *target, ConfidenceListSampleType *quality=nullptr, ProbaListSampleType *proba=nullptr) const
itk::Statistics::ListSample< InputSampleType > InputListSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.