21 #ifndef otbMachineLearningModel_hxx
22 #define otbMachineLearningModel_hxx
30 #include "itkMultiThreader.h"
// Constructor (fragment: the signature line and any further initializers are
// not visible in this excerpt). The visible initializers start every model with
// regression mode disabled, regression declared unsupported, the confidence
// index flag off, and DoPredictBatch flagged as not internally multithreaded.
// NOTE(review): presumably concrete models flip m_IsRegressionSupported /
// m_IsDoPredictBatchMultiThreaded in their own constructors -- TODO confirm.
35 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
37 : m_RegressionMode(false),
38 m_IsRegressionSupported(false),
39 m_ConfidenceIndex(false),
41 m_IsDoPredictBatchMultiThreaded(false),
// Regression-mode setter (fragment: the signature is not visible here;
// presumably SetRegressionMode(bool flag) -- TODO confirm).
// Enabling regression on a model whose m_IsRegressionSupported is false raises
// an exception through itkGenericExceptionMacro. The member is assigned only
// when the requested value differs from the current one; whatever follows the
// assignment (possibly a Modified() call) is not visible in this excerpt.
46 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
49 if (flag && !m_IsRegressionSupported)
51 itkGenericExceptionMacro(<<
"Regression mode not implemented.");
53 if (m_RegressionMode != flag)
55 m_RegressionMode = flag;
// Single-sample prediction (fragment: signature and any preceding checks are
// not visible here). The visible body forwards the input sample together with
// the optional quality and proba outputs to DoPredict and returns its result.
60 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
66 return this->DoPredict(input, quality, proba);
// Batch prediction over a whole input list sample (fragment: signature, braces
// and preprocessor guards are not visible in this excerpt).
// - Allocates a new target list sample sized like the input.
// - Resizes the optional quality / proba lists when they are non-null.
// - If m_IsDoPredictBatchMultiThreaded is set, calls DoPredictBatch once over
//   the full range [0, input->Size()); presumably the model parallelizes
//   internally in that case -- TODO confirm.
// - Otherwise, inside an OpenMP parallel region, splits the input into
//   nb_batches = min(#threads, input size) contiguous ranges of
//   input->Size() / nb_batches samples. The last thread's range also absorbs
//   the division remainder.
// - The final DoPredictBatch call over the full range looks like the non-OpenMP
//   fallback (its #else guard is elided) -- NOTE(review): verify.
70 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
76 typename TargetListSampleType::Pointer targets = TargetListSampleType::New();
77 targets->Resize(input->Size());
79 if (quality !=
nullptr)
82 quality->Resize(input->Size());
84 if (proba != ITK_NULLPTR)
87 proba->Resize(input->Size());
89 if (m_IsDoPredictBatchMultiThreaded)
92 this->DoPredictBatch(input, 0, input->Size(), targets, quality, proba);
99 unsigned int nb_threads(0), threadId(0), nb_batches(0);
// NOTE(review): every thread writes the shared nb_threads / nb_batches with the
// same values; the values agree, but the concurrent writes are still formally a
// data race. Computing them once (e.g. in a single construct) would avoid it.
101 #pragma omp parallel shared(nb_threads, nb_batches) private(threadId)
// NOTE(review): omp_set_num_threads only affects parallel regions started
// *after* this call, so it does not size the region it is executed in. It
// likely belongs before the #pragma above.
104 omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
105 nb_threads = omp_get_num_threads();
106 threadId = omp_get_thread_num();
107 nb_batches = std::min(nb_threads, (
unsigned int)input->Size());
109 if (threadId < nb_batches)
111 unsigned int batch_size = ((
unsigned int)input->Size() / nb_batches);
112 unsigned int batch_start = threadId * batch_size;
// The remainder goes to thread nb_threads - 1. This matches "last batch" only
// because the remainder is 0 whenever nb_batches < nb_threads (then
// nb_batches == input size, batch_size == 1). Comparing against
// nb_batches - 1 would state the intent directly.
113 if (threadId == nb_threads - 1)
115 batch_size += input->Size() % nb_batches;
118 this->DoPredictBatch(input, batch_start, batch_size, targets, quality, proba);
122 this->DoPredictBatch(input, 0, input->Size(), targets, quality, proba);
// Default per-sample implementation of batch prediction over the half-open
// range [startIndex, startIndex + size) (fragment: signature, braces and the
// declarations of `confidence` / `prob` are not visible here).
// Preconditions (debug asserts): input and targets are non-null and the same
// size; quality and proba may each be null, otherwise they match the input size.
// Throws via itkExceptionMacro if the requested range extends past the input.
// Each sample is predicted with DoPredict, and the results are written at the
// same index into targets and, depending on the branch, quality and proba.
129 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
134 assert(input !=
nullptr);
135 assert(targets !=
nullptr);
137 assert(input->Size() == targets->Size() &&
"Input sample list and target label list do not have the same size.");
138 assert(((quality ==
nullptr) || (quality->Size() == input->Size())) &&
139 "Quality samples list is not null and does not have the same size as input samples list");
140 assert(((proba ==
nullptr) || (input->Size() == proba->Size())) &&
"Proba sample list and target label list do not have the same size.");
142 if (startIndex + size > input->Size())
144 itkExceptionMacro(<<
"requested range [" << startIndex <<
", " << startIndex + size <<
"[ partially outside input sample list range.[0," << input->Size()
// NOTE(review): this branch is selected by proba alone but also writes through
// quality. The asserts above allow quality == nullptr, so a caller that passes
// proba without quality will dereference a null pointer, unless an elided line
// guards it, which nothing visible suggests. Fix: also require quality here, or
// null-check quality before SetMeasurementVector.
148 if (proba !=
nullptr)
150 for (
unsigned int id = startIndex;
id < startIndex + size; ++id)
154 const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(
id), &confidence, &prob);
155 quality->SetMeasurementVector(
id, confidence);
156 proba->SetMeasurementVector(
id, prob);
157 targets->SetMeasurementVector(
id, target);
// Quality requested without proba: only the confidence output is collected.
160 else if (quality != ITK_NULLPTR)
162 for (
unsigned int id = startIndex;
id < startIndex + size; ++id)
165 const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(
id), &confidence);
166 quality->SetMeasurementVector(
id, confidence);
167 targets->SetMeasurementVector(
id, target);
// No optional outputs requested: only the predicted targets are written.
172 for (
unsigned int id = startIndex;
id < startIndex + size; ++id)
174 const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(
id));
175 targets->SetMeasurementVector(
id, target);
// Debug printing (fragment: signature and any further output are not visible
// here). The visible body delegates to Superclass::PrintSelf with the same
// stream and indentation.
180 template <
class TInputValue,
class TOutputValue,
class TConf
idenceValue>
184 Superclass::PrintSelf(os, indent);