21 #ifndef otbMachineLearningModel_hxx
22 #define otbMachineLearningModel_hxx
29 #include "itkMultiThreaderBase.h"
// Default constructor of MachineLearningModel.
// Every member initializer visible here sets its flag to false:
//   - m_RegressionMode: the model starts in classification mode;
//   - m_IsRegressionSupported: regression is unsupported unless a derived
//     model turns it on (SetRegressionMode(true) throws while it is false);
//   - m_ConfidenceIndex: confidence output is off by default;
//   - m_IsDoPredictBatchMultiThreaded: PredictBatch() will split the work with
//     OpenMP itself instead of trusting DoPredictBatch() to be multithreaded.
// NOTE(review): this listing is missing source lines (the constructor
// signature and at least one initializer between m_ConfidenceIndex and
// m_IsDoPredictBatchMultiThreaded); check the full file for the rest.
34 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
36 : m_RegressionMode(false),
37 m_IsRegressionSupported(false),
38 m_ConfidenceIndex(false),
40 m_IsDoPredictBatchMultiThreaded(false),
// SetRegressionMode(bool flag): switch the model between classification
// (flag == false) and regression (flag == true).
// Raises an exception through itkGenericExceptionMacro when regression is
// requested on a model whose m_IsRegressionSupported flag is false.
// m_RegressionMode is only written when the requested value differs from the
// current one.
// NOTE(review): the listing is missing a line inside that branch (presumably a
// Modified() call); confirm against the full source.
45 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
// Refuse to enter regression mode on models that cannot do regression.
48 if (flag && !m_IsRegressionSupported)
50 itkGenericExceptionMacro(<<
"Regression mode not implemented.");
// Only touch the state when it actually changes.
52 if (m_RegressionMode != flag)
54 m_RegressionMode = flag;
// Predict(input, quality, proba): predict the target of a single sample.
// quality and proba are optional output pointers (nullptr by default in the
// declaration). The call is forwarded unchanged to this->DoPredict(), which is
// presumably the per-model hook implemented by concrete classifiers.
// NOTE(review): the signature and any lines before the return are not
// visible in this listing.
59 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
65 return this->DoPredict(input, quality, proba);
// PredictBatch(input, quality, proba): predict every sample of `input`.
// A fresh target list is created and sized like `input`. The optional
// `quality` and `proba` lists are resized likewise when they are non-null.
// The work is then delegated to DoPredictBatch(), in one of two ways:
//   - if m_IsDoPredictBatchMultiThreaded is set, DoPredictBatch() is called
//     once on the whole range [0, input->Size());
//   - otherwise the range is cut into contiguous chunks, one per OpenMP
//     thread, and each thread calls DoPredictBatch() on its own chunk.
// The declaration returns TargetListSampleType::Pointer. The return statement
// (presumably `targets`) and the preprocessor guards around the OpenMP code
// are not visible in this listing.
69 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
75 typename TargetListSampleType::Pointer targets = TargetListSampleType::New();
76 targets->Resize(input->Size());
78 if (quality !=
nullptr)
81 quality->Resize(input->Size());
83 if (proba != ITK_NULLPTR)
86 proba->Resize(input->Size());
// Model already parallelizes DoPredictBatch(): a single call over everything.
88 if (m_IsDoPredictBatchMultiThreaded)
91 this->DoPredictBatch(input, 0, input->Size(), targets, quality, proba);
// NOTE(review): nb_threads and nb_batches are shared and written by every
// thread (always with the same value). Strictly speaking this is a data race;
// computing them once before the region would avoid it.
98 unsigned int nb_threads(0), threadId(0), nb_batches(0);
100 #pragma omp parallel shared(nb_threads, nb_batches) private(threadId)
// NOTE(review): per the OpenMP spec, omp_set_num_threads() called inside a
// parallel region only affects regions started later by this thread. The size
// of the current team is not changed here.
103 omp_set_num_threads(itk::MultiThreaderBase::GetGlobalDefaultNumberOfThreads());
104 nb_threads = omp_get_num_threads();
105 threadId = omp_get_thread_num();
// Never use more batches than there are samples. With an empty input,
// nb_batches is 0 and no thread enters the branch below, so there is no
// division by zero.
106 nb_batches = std::min(nb_threads,(
unsigned int)input->Size());
108 if(threadId<nb_batches)
110 unsigned int batch_size = ((
unsigned int)input->Size() / nb_batches);
111 unsigned int batch_start = threadId * batch_size;
// The last thread also takes the remainder of the integer division.
// NOTE(review): this compares with nb_threads - 1, not nb_batches - 1. That is
// harmless: nb_batches < nb_threads only when nb_batches == input->Size(), and
// then the remainder is 0.
112 if (threadId == nb_threads - 1)
114 batch_size += input->Size() % nb_batches;
117 this->DoPredictBatch(input, batch_start, batch_size, targets, quality, proba);
// Presumably the sequential fallback used when OpenMP is unavailable (the
// surrounding #else is not visible in this listing).
121 this->DoPredictBatch(input, 0, input->Size(), targets, quality, proba);
128 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
133 assert(input !=
nullptr);
134 assert(targets !=
nullptr);
136 assert(input->Size() == targets->Size() &&
"Input sample list and target label list do not have the same size.");
137 assert(((quality ==
nullptr) || (quality->Size() == input->Size())) &&
138 "Quality samples list is not null and does not have the same size as input samples list");
139 assert(((proba ==
nullptr) || (input->Size() == proba->Size())) &&
"Proba sample list and target label list do not have the same size.");
141 if (startIndex + size > input->Size())
143 itkExceptionMacro(<<
"requested range [" << startIndex <<
", " << startIndex + size <<
"[ partially outside input sample list range.[0," << input->Size()
147 if (proba !=
nullptr)
149 for (
unsigned int id = startIndex;
id < startIndex + size; ++id)
151 ProbaSampleType prob;
152 ConfidenceValueType confidence = 0;
153 const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(
id), &confidence, &prob);
154 quality->SetMeasurementVector(
id, confidence);
155 proba->SetMeasurementVector(
id, prob);
156 targets->SetMeasurementVector(
id, target);
159 else if (quality != ITK_NULLPTR)
161 for (
unsigned int id = startIndex;
id < startIndex + size; ++id)
163 ConfidenceValueType confidence = 0;
164 const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(
id), &confidence);
165 quality->SetMeasurementVector(
id, confidence);
166 targets->SetMeasurementVector(
id, target);
171 for (
unsigned int id = startIndex;
id < startIndex + size; ++id)
173 const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(
id));
174 targets->SetMeasurementVector(
id, target);
// PrintSelf(os, indent): print this object's state for debugging.
// The visible body only delegates to Superclass::PrintSelf(os, indent); no
// MachineLearningModel-specific field is printed in the lines shown here.
// NOTE(review): the signature lines are missing from this listing; the
// declaration is `void PrintSelf(std::ostream &os, itk::Indent indent) const override`.
179 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
183 Superclass::PrintSelf(os, indent);
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests,...
itk::Statistics::ListSample< ConfidenceSampleType > ConfidenceListSampleType
itk::Statistics::ListSample< ProbaSampleType > ProbaListSampleType
void SetRegressionMode(bool flag)
void PrintSelf(std::ostream &os, itk::Indent indent) const override
TargetListSampleType::Pointer PredictBatch(const InputListSampleType *input, ConfidenceListSampleType *quality=nullptr, ProbaListSampleType *proba=nullptr) const
TargetSampleType Predict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
itk::VariableLengthVector< double > ProbaSampleType
virtual void DoPredictBatch(const InputListSampleType *input, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *target, ConfidenceListSampleType *quality=nullptr, ProbaListSampleType *proba=nullptr) const
itk::Statistics::ListSample< InputSampleType > InputListSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.