21 #ifndef otbSharkRandomForestsMachineLearningModel_hxx
22 #define otbSharkRandomForestsMachineLearningModel_hxx
28 #if defined(__GNUC__) || defined(__clang__)
29 #pragma GCC diagnostic push
30 #pragma GCC diagnostic ignored "-Wshadow"
31 #pragma GCC diagnostic ignored "-Wunused-parameter"
32 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
33 #pragma GCC diagnostic ignored "-Wignored-qualifiers"
35 #if defined(__GNUC__) || defined(__clang__)
36 #pragma GCC diagnostic pop
40 #include "otbSharkUtils.h"
// Assigns the default value of every configuration flag of the model.
46 template <
class TInputValue,
class TOutputValue>
// Both flags are switched on (presumably advertising to the base class that
// confidence and per-class probability outputs are available — TODO confirm
// against the base class).
49 this->m_ConfidenceIndex =
true;
50 this->m_ProbaIndex =
true;
// Regression is declared unsupported.
51 this->m_IsRegressionSupported =
false;
// Batch prediction is flagged as multi-threaded.
52 this->m_IsDoPredictBatchMultiThreaded =
true;
// Class labels are normalized by default; training and prediction consult
// this flag before touching m_ClassDictionary.
53 this->m_NormalizeClassLabels =
true;
// The confidence computation receives this flag as its computeMargin argument.
54 this->m_ComputeMargin =
false;
// Trains m_RFModel with m_RFTrainer on the input/target list samples held by
// this object.
58 template <
class TInputValue,
class TOutputValue>
// Give OpenMP the same thread count as ITK's global default.
62 omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
// Shark-side containers receiving the converted features and labels.
65 std::vector<shark::RealVector> features;
66 std::vector<unsigned int> class_labels;
68 Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
69 Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
// When normalization is enabled, class_labels is passed together with
// m_ClassDictionary; prediction later indexes m_ClassDictionary with the
// model output to obtain the returned label. (Presumably the labels are
// remapped to contiguous indices here — TODO confirm in otbSharkUtils.h.)
70 if (m_NormalizeClassLabels)
72 Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
// Pair every feature vector with its label in a single Shark dataset.
74 shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features, class_labels);
// Forward the hyper-parameters stored on this object to the trainer.
77 m_RFTrainer.setMTry(m_MTry);
78 m_RFTrainer.setNTrees(m_NumberOfTrees);
79 m_RFTrainer.setNodeSize(m_NodeSize);
// Fit the model on the assembled dataset.
81 m_RFTrainer.train(m_RFModel, TrainSamples);
// Confidence computation over a vector of class probabilities (probas),
// controlled by the computeMargin flag.
84 template <
class TInputValue,
class TOutputValue>
// Preconditions (checked only where assert is active): probas is non-empty,
// and holds at least two entries when computeMargin is requested.
88 assert(!probas.empty() &&
"probas vector is empty");
89 assert((!computeMargin || probas.size() > 1) &&
"probas size should be at least 2 if computeMargin is true");
// Partial descending ordering: after this call probas[0] is the largest value
// and probas[1] the second largest. Note that probas is reordered in place.
94 std::nth_element(probas.begin(), probas.begin() + 1, probas.end(), std::greater<double>());
// Largest probability in the vector.
99 auto max_proba = *(std::max_element(probas.begin(), probas.end()));
// Single-sample prediction: evaluates m_RFModel on `value`, optionally fills
// *quality and *proba, and writes the predicted label into target[0].
105 template <
class TInputValue,
class TOutputValue>
// Copy the input sample into a Shark vector.
// NOTE(review): samples is constructed with value.Size() elements and the loop
// then push_back()s value.Size() more. If shark::RealVector's sized
// constructor and push_back behave like std::vector's, the features end up at
// indices [Size, 2*Size) behind Size default elements — confirm, and consider
// assigning samples[i] = value[i] instead.
110 shark::RealVector samples(value.Size());
111 for (
size_t i = 0; i < value.Size(); i++)
113 samples.push_back(value[i]);
// The decision function is evaluated only when the caller supplied at least
// one of the optional outputs.
115 if (quality !=
nullptr || proba !=
nullptr)
117 shark::RealVector probas = m_RFModel.decisionFunction()(samples);
118 if (quality !=
nullptr)
// Confidence is derived from the probabilities; m_ComputeMargin is forwarded
// as the computeMargin argument.
120 (*quality) = ComputeConfidence(probas, m_ComputeMargin);
122 if (proba !=
nullptr)
// Each probability is scaled by 1000 and converted to unsigned int (the
// conversion truncates the fractional part).
// NOTE(review): *proba is indexed up to probas.size() with no resize visible
// here — presumably the caller pre-sizes it; confirm.
124 for (
size_t i = 0; i < probas.size(); i++)
127 (*proba)[i] =
static_cast<unsigned int>(probas[i] * 1000);
// Evaluate the model; the prediction is written into res.
132 m_RFModel.eval(samples, res);
// With normalization enabled, res is used as an index into m_ClassDictionary
// to obtain the label; otherwise res is cast to the output type directly.
135 if (m_NormalizeClassLabels)
137 target[0] = m_ClassDictionary[
static_cast<TOutputValue
>(res)];
141 target[0] =
static_cast<TOutputValue
>(res);
// Batch prediction over the samples of `input` starting at startIndex and
// spanning `size` entries; results are written into `targets` and, when
// provided, into `quality` and `proba`.
146 template <
class TInputValue,
class TOutputValue>
// Preconditions (checked only where assert is active): input and targets are
// non-null and equally sized; quality and proba, when given, match the size
// of input.
151 assert(input !=
nullptr);
152 assert(targets !=
nullptr);
154 assert(input->Size() == targets->Size() &&
"Input sample list and target label list do not have the same size.");
155 assert(((quality ==
nullptr) || (quality->Size() == input->Size())) &&
156 "Quality samples list is not null and does not have the same size as input samples list");
// NOTE(review): the message below mentions the target label list, while the
// condition compares the sizes of input and proba.
157 assert(((proba ==
nullptr) || (input->Size() == proba->Size())) &&
"Proba sample list and target label list do not have the same size.");
// Reject a requested range that extends past the end of the input list.
159 if (startIndex + size > input->Size())
161 itkExceptionMacro(<<
"requested range [" << startIndex <<
", " << startIndex + size <<
"[ partially outside input sample list range.[0," << input->Size()
// Convert the requested range of samples into a Shark data container.
165 std::vector<shark::RealVector> features;
166 Shark::ListSampleRangeToSharkVector(input, features, startIndex, size);
167 shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
// Give OpenMP the same thread count as ITK's global default.
170 omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
// The decision function is evaluated once for the whole batch, and only when
// at least one of the optional outputs was requested.
173 if (proba !=
nullptr || quality !=
nullptr)
175 shark::Data<shark::RealVector> probas = m_RFModel.decisionFunction()(inputSamples);
176 if (proba !=
nullptr)
// Output positions are addressed through id, which starts at startIndex.
178 unsigned int id = startIndex;
179 for (shark::RealVector&& p : probas.elements())
// Each probability is scaled by 1000 before being stored.
// NOTE(review): unlike the single-sample path, no explicit cast is applied
// here; the conversion depends on the element type of prob.
182 for (
size_t i = 0; i < p.size(); i++)
184 prob[i] = p[i] * 1000;
186 proba->SetMeasurementVector(
id, prob);
190 if (quality !=
nullptr)
// Same addressing scheme for the confidence output.
192 unsigned int id = startIndex;
193 for (shark::RealVector&& p : probas.elements())
// m_ComputeMargin is forwarded as the computeMargin argument.
196 auto conf = ComputeConfidence(p, m_ComputeMargin);
198 quality->SetMeasurementVector(
id, confidence);
// Predict the whole batch in one call to the model.
204 auto prediction = m_RFModel(inputSamples);
205 unsigned int id = startIndex;
206 for (
const auto& p : prediction.elements())
// With normalization enabled, the prediction is used as an index into
// m_ClassDictionary to obtain the label; otherwise it is cast directly.
209 if (m_NormalizeClassLabels)
211 target[0] = m_ClassDictionary[
static_cast<TOutputValue
>(p)];
215 target[0] =
static_cast<TOutputValue
>(p);
217 targets->SetMeasurementVector(
id, target);
// Writes the model to the text file `filename`: a '#' header, the class
// dictionary when label normalization is enabled, then the serialized model.
222 template <
class TInputValue,
class TOutputValue>
225 std::ofstream ofs(filename);
// Raised through the ITK exception macro; the message carries the file name.
228 itkExceptionMacro(<<
"Error opening " << filename.c_str());
// Header: '#' followed by the model name.
231 ofs <<
"#" << m_RFModel.name();
// The " with_dictionary" marker records that a dictionary follows; loading
// looks for this marker in the header.
232 if (m_NormalizeClassLabels)
233 ofs <<
" with_dictionary";
// Dictionary section: the entry count and a space are written, then the
// entries of m_ClassDictionary are iterated.
235 if (m_NormalizeClassLabels)
237 ofs << m_ClassDictionary.size() <<
" ";
238 for (
const auto& l : m_ClassDictionary)
// Serialize the model itself through a Shark text archive.
244 shark::TextOutArchive oa(ofs);
245 m_RFModel.save(oa, 0);
// Reads a model from the text file `filename`: inspects the optional '#'
// header, restores the class dictionary when label normalization is active,
// then deserializes the model.
248 template <
class TInputValue,
class TOutputValue>
251 std::ifstream ifs(filename);
// A leading '#' marks a header line.
// NOTE(review): std::string::at(0) throws std::out_of_range when `line` is
// empty.
257 if (line.at(0) ==
'#')
// The header must contain the model name; otherwise the file is rejected.
259 if (line.find(m_RFModel.name()) == std::string::npos)
260 itkExceptionMacro(
"The model file : " + filename +
" cannot be read.");
// No "with_dictionary" marker in the header: switch label normalization off.
261 if (line.find(
"with_dictionary") == std::string::npos)
263 m_NormalizeClassLabels =
false;
// Rewind the stream to its beginning (presumably the path taken when the
// first line is not a '#' header — TODO confirm).
270 ifs.seekg(0, std::ios::beg);
// Dictionary section: m_ClassDictionary is resized to nbLabels and every
// entry is assigned from `label`.
272 if (m_NormalizeClassLabels)
276 m_ClassDictionary.resize(nbLabels);
277 for (
size_t i = 0; i < nbLabels; ++i)
281 m_ClassDictionary[i] = label;
// Deserialize the model through a Shark text archive.
284 shark::TextInArchive ia(ifs);
285 m_RFModel.load(ia, 0);
289 template <
class TInputValue,
class TOutputValue>
304 template <
class TInputValue,
class TOutputValue>
// Printing of the object's state to `os` at the given `indent`.
310 template <
class TInputValue,
class TOutputValue>
// Delegate to the superclass implementation.
314 Superclass::PrintSelf(os, indent);