20 #ifndef otbSharkKMeansMachineLearningModel_hxx
21 #define otbSharkKMeansMachineLearningModel_hxx
29 #if defined(__GNUC__) || defined(__clang__)
30 #pragma GCC diagnostic push
32 #if (defined (__GNUC__) && (__GNUC__ >= 9)) || (defined (__clang__) && (__clang_major__ >= 10))
33 #pragma GCC diagnostic ignored "-Wdeprecated-copy"
36 #pragma GCC diagnostic ignored "-Wshadow"
37 #pragma GCC diagnostic ignored "-Wunused-parameter"
38 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
39 #pragma GCC diagnostic ignored "-Wignored-qualifiers"
42 #include "otb_shark.h"
43 #include "otbSharkUtils.h"
44 #include "shark/Algorithms/KMeans.h"
45 #include "shark/Models/Clustering/HardClusteringModel.h"
46 #include "shark/Models/Clustering/SoftClusteringModel.h"
47 #include <shark/Data/Csv.h>
49 #if defined(__GNUC__) || defined(__clang__)
50 #pragma GCC diagnostic pop
// NOTE(review): only the template header of this member definition is visible
// in this excerpt; its signature and body are missing, so what it defines
// cannot be determined here. TODO confirm against the complete file.
56 template <
class TInputValue,
class TOutputValue>
// NOTE(review): only the template header of this member definition is visible
// in this excerpt; its signature and body are missing, so what it defines
// cannot be determined here. TODO confirm against the complete file.
65 template <
class TInputValue,
class TOutputValue>
// Training routine (presumably Train(); the signature line is not visible in
// this excerpt — confirm). Converts the input list sample to Shark data, runs
// k-means and rebuilds the clustering model from the resulting centroids.
71 template <
class TInputValue,
class TOutputValue>
// Copy the input list sample into a vector of Shark RealVectors, then wrap it
// in a shark::Data container for the Shark algorithms.
75 std::vector<shark::RealVector> vector_data;
76 otb::Shark::ListSampleToSharkVector(this->GetInputListSample(), vector_data);
77 shark::Data<shark::RealVector> data = shark::createDataFromRange(vector_data);
// Run k-means with m_K clusters, writing the result into m_Centroids;
// m_MaximumNumberOfIterations is passed as Shark's iteration limit argument.
81 shark::kMeans(data, m_K, m_Centroids, m_MaximumNumberOfIterations);
// Rebuild the clustering model around the freshly computed centroids. The model
// is given the address of m_Centroids, so presumably it keeps that pointer and
// m_Centroids must outlive it — confirm against ClusteringModelType.
82 m_ClusteringModel = std::make_shared<ClusteringModelType>(&m_Centroids);
// Single-sample prediction: converts `value` to a Shark vector, rejects
// unsupported probability requests and stores the predicted cluster id in
// target[0]. NOTE(review): the signature, the model evaluation producing
// `predictedValue`, and the declaration of `target` are not visible in this
// excerpt — confirm against the complete file.
85 template <
class TInputValue,
class TOutputValue>
// NOTE(review): likely BUG — `data` is constructed with value.Size() elements
// and the loop then push_back()s every component. If RealVector::push_back
// appends (as remora's vector does), the sample fed to the model has
// 2 * value.Size() entries: value.Size() leading zeros followed by the real
// features. The loop should assign in place (`data(i) = value[i];`), or `data`
// should be constructed empty.
89 shark::RealVector data(value.Size());
90 for (
size_t i = 0; i < value.Size(); i++)
92 data.push_back(value[i]);
// A confidence value was requested; how this branch handles it is not visible
// in this excerpt.
96 if (quality !=
nullptr)
// A per-class probability was requested: refused with an exception unless
// m_ProbaIndex is set.
102 if (proba !=
nullptr)
104 if (!this->m_ProbaIndex)
106 itkExceptionMacro(
"Probability per class not available for this classifier !");
// Store the predicted cluster id, cast to the output value type, in the first
// component of the target sample.
111 target[0] =
static_cast<TOutputValue
>(predictedValue);
// Batch prediction over the sub-range [startIndex, startIndex + size) of
// `input`, writing one cluster label per sample into `targets`.
// NOTE(review): the signature is not visible in this excerpt; the parameter
// names used below are taken from their uses in the body.
115 template <
class TInputValue,
class TOutputValue>
// Preconditions. These are plain asserts, so they are compiled out under
// NDEBUG: input and targets must be non-null, have the same size, and
// `quality`, when given, must match the input size.
122 assert(input !=
nullptr);
123 assert(targets !=
nullptr);
126 assert(input->Size() == targets->Size() &&
"Input sample list and target label list do not have the same size.");
127 assert(((quality ==
nullptr) || (quality->Size() == input->Size())) &&
128 "Quality samples list is not null and does not have the same size as input samples list");
// Reject ranges that extend past the end of the input list.
// NOTE(review): if startIndex and size are unsigned, `startIndex + size` can
// wrap for very large values and slip past this check. Testing
// `startIndex > input->Size() || size > input->Size() - startIndex` avoids that.
129 if (startIndex + size > input->Size())
131 itkExceptionMacro(<<
"requested range [" << startIndex <<
", " << startIndex + size <<
"[ partially outside input sample list range.[0," << input->Size()
// Convert the requested range into Shark vectors and wrap them as shark::Data.
136 std::vector<shark::RealVector> features;
137 otb::Shark::ListSampleRangeToSharkVector(input, features, startIndex, size);
138 shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
// Evaluate the clustering model on the whole batch at once. The message that
// follows belongs to an error path whose opening lines are not visible in this
// excerpt.
140 shark::Data<ClusteringOutputType> clusters;
143 clusters = (*m_ClusteringModel)(inputSamples);
148 "Failed to run clustering classification. "
149 "The number of features of input samples and the model could differ.");
// Copy each predicted cluster index, cast to TOutputValue, into `targets`
// starting at row startIndex (the increment of `id` is not visible here).
// NOTE(review): `id` is an unsigned int; if startIndex has a wider type this
// narrows — confirm the declared type of startIndex.
152 unsigned int id = startIndex;
153 for (
const auto& p : clusters.elements())
156 target[0] =
static_cast<TOutputValue
>(p);
157 targets->SetMeasurementVector(
id, target);
// Confidence values were requested: iterate over the processed range (the loop
// body is not visible in this excerpt).
162 if (quality !=
nullptr)
164 for (
unsigned int qid = startIndex; qid < startIndex + size; ++qid)
// Per-class probabilities are refused unless m_ProbaIndex is set.
169 if (proba !=
nullptr && !this->m_ProbaIndex)
171 itkExceptionMacro(
"Probability per class not available for this classifier !");
// Writes the clustering model to `filename`. The first line is "#" followed by
// the model name; the model is then serialized with a Shark text archive.
176 template <
class TInputValue,
class TOutputValue>
179 std::ofstream ofs(filename);
// Raised when the output file cannot be opened (the guarding stream check is
// not visible in this excerpt).
182 itkExceptionMacro(<<
"Error opening " << filename.c_str());
// Header line, which the loader presumably uses to recognise the model type.
184 ofs <<
"#" << m_ClusteringModel->name() << std::endl;
185 shark::TextOutArchive oa(ofs);
// NOTE(review): saved with version argument 1, while the loading code in this
// file passes 0 to load(). Presumably the model ignores the version — confirm.
186 m_ClusteringModel->save(oa, 1);
// Loads a clustering model from `filename`. m_CanRead is set according to
// whether the first line contains the model's name, and the model is then
// deserialized from a Shark text archive.
189 template <
class TInputValue,
class TOutputValue>
193 std::ifstream ifs(filename);
// Read the header line (the declaration of `line` is not visible here) and
// check that it names this model type.
198 std::getline(ifs, line);
199 m_CanRead = line.find(m_ClusteringModel->name()) != std::string::npos;
// NOTE(review): the control flow between the header check and deserialization
// (e.g. an early return when m_CanRead is false) is not visible in this excerpt.
205 shark::TextInArchive ia(ifs);
206 m_ClusteringModel->load(ia, 0);
// NOTE(review): only the template header of this member definition is visible
// in this excerpt; its signature and body are missing. TODO confirm against the
// complete file.
210 template <
class TInputValue,
class TOutputValue>
// NOTE(review): only the template header of this member definition is visible
// in this excerpt; its signature and body are missing. TODO confirm against the
// complete file.
225 template <
class TInputValue,
class TOutputValue>
// Writes the centroid matrix (m_Centroids.centroids()) to `filename` in CSV
// format, using a single space as the field separator.
231 template <
class TInputValue,
class TOutputValue>
234 shark::exportCSV(m_Centroids.centroids(), filename,
' ');
// Prints the object state; this visible portion only delegates to the
// superclass implementation.
237 template <
class TInputValue,
class TOutputValue>
241 Superclass::PrintSelf(os, indent);