20 #ifndef otbPCAModel_hxx
21 #define otbPCAModel_hxx
27 #if defined(__GNUC__) || defined(__clang__)
28 #pragma GCC diagnostic push
30 #if (defined (__GNUC__) && (__GNUC__ >= 9)) || (defined (__clang__) && (__clang_major__ >= 10))
31 #pragma GCC diagnostic ignored "-Wdeprecated-copy"
34 #pragma GCC diagnostic ignored "-Wshadow"
35 #pragma GCC diagnostic ignored "-Wunused-parameter"
36 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
38 #include "otbSharkUtils.h"
40 #include <shark/ObjectiveFunctions/ErrorFunction.h>
41 #include <shark/Algorithms/GradientDescent/Rprop.h>
42 #include <shark/ObjectiveFunctions/Loss/SquaredLoss.h>
43 #include <shark/ObjectiveFunctions/Regularizer.h>
44 #include <shark/ObjectiveFunctions/ErrorFunction.h>
45 #if defined(__GNUC__) || defined(__clang__)
46 #pragma GCC diagnostic pop
// Appears to be a constructor body.
// Sets m_IsDoPredictBatchMultiThreaded to true and starts with m_Dimension = 0.
// NOTE(review): this copy of the file is damaged. Each surviving line is
// prefixed with its original line number (e.g. "55 "), and the function
// declarator and braces are missing. It will not compile as-is; restore it
// from the upstream source before building.
// NOTE(review): a zero m_Dimension appears to mean "not chosen yet". The code
// that reads a saved model (the "pca" header check on `ifs`) replaces 0 with
// m_Encoder.outputShape()[0].
52 template <
class TInputValue>
55 this->m_IsDoPredictBatchMultiThreaded =
true;
56 this->m_Dimension = 0;
// NOTE(review): only this template header survives in this copy of the file.
// The declarator and body that should follow it are missing; the next
// surviving line is another template header. The leading "59 " is an original
// line-number prefix, not code. Restore this definition from the upstream
// source.
59 template <
class TInputValue>
// Fragment that builds the PCA from the current input list sample.
// - Every sample is converted to a shark::RealVector.
// - The vectors are wrapped in a shark::Data set and handed to m_PCA via
//   setData.
// - The encoder (m_Encoder) and decoder (m_Decoder) models are then extracted
//   with this->m_Dimension as the size argument.
// Presumably m_Dimension is the number of principal components kept. It is
// also presumed that a value of 0 makes Shark keep all components. Confirm
// both against the Shark PCA::encoder/decoder documentation.
// NOTE(review): the declarator and braces are missing, and lines carry their
// original line numbers as a prefix; this text does not compile as-is.
64 template <
class TInputValue>
// Copy every sample of the input list sample into plain Shark vectors.
67 std::vector<shark::RealVector> features;
69 Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
// m_PCA.setData takes a shark::Data container, so wrap the vectors first.
71 shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
72 m_PCA.setData(inputSamples);
73 m_PCA.encoder(m_Encoder, this->m_Dimension);
74 m_PCA.decoder(m_Decoder, this->m_Dimension);
// NOTE(review): only this template header survives in this copy of the file.
// The declarator and body that should follow it are missing; the next
// surviving line is another template header. The leading "77 " is an original
// line-number prefix, not code. Restore this definition from the upstream
// source.
77 template <
class TInputValue>
// NOTE(review): only this template header survives in this copy of the file.
// The declarator and body that should follow it are missing; the next
// surviving line is another template header. The leading "92 " is an original
// line-number prefix, not code. Restore this definition from the upstream
// source.
92 template <
class TInputValue>
// Fragment that writes the model to `filename`.
// It first writes the line "pca", the same header string the model-reading
// code compares against. It then opens a shark::TextOutArchive over the same
// stream; the archive write itself is not present in this copy.
// When m_WriteEigenvectors is true, it also writes a human-readable report to
// `filename + ".txt"`. The report contains:
// - m_PCA's eigenvectors and eigenvalues;
// - the squared-loss reconstruction error of m_Decoder(m_Encoder(x)) over the
//   current input list sample.
// NOTE(review): the declarator and braces are missing, and lines carry their
// original line numbers as a prefix; this text does not compile as-is.
98 template <
class TInputValue>
// NOTE(review): whether `ofs` opened successfully is not checked in the
// visible lines. By default a failed open silently writes nothing.
101 std::ofstream ofs(filename);
102 ofs <<
"pca" << std::endl;
103 shark::TextOutArchive oa(ofs);
// NOTE(review): comparing a flag against `true` is redundant; `if (flag)` is
// equivalent.
107 if (this->m_WriteEigenvectors ==
true)
// NOTE(review): whether `otxt` opened successfully is not checked either.
109 std::ofstream otxt(filename +
".txt");
// The eigen-decomposition comes from m_PCA. The visible training code fills
// m_PCA via setData; the visible model-reading lines do not touch m_PCA.
111 otxt <<
"Eigenvectors : " << m_PCA.eigenvectors() << std::endl;
112 otxt <<
"Eigenvalues : " << m_PCA.eigenvalues() << std::endl;
// Re-read the input list sample so the report can include the reconstruction
// error. NOTE(review): this assumes GetInputListSample() still returns data at
// save time. That may not hold when a model was loaded rather than trained;
// confirm.
114 std::vector<shark::RealVector> features;
116 shark::SquaredLoss<shark::RealVector> loss;
117 Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
118 shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
119 otxt <<
"Reconstruction error : " << loss.eval(inputSamples, m_Decoder(m_Encoder(inputSamples))) << std::endl;
// Fragment that reads a model from `filename`.
// - The first line must be "pca"; otherwise an ITK exception is thrown.
// - A shark::TextInArchive is then built over the rest of the stream. The read
//   from `ia` is not present in this copy.
// - If m_Dimension is still 0, it is set to the encoder's output size.
// - The encoder matrix is then resized to m_Dimension rows by input-size
//   columns and installed back into m_Encoder with the original offset.
// NOTE(review): the declarator and braces are missing, and lines carry their
// original line numbers as a prefix; this text does not compile as-is.
// NOTE(review): `encoder` is declared on a line missing from this copy. It is
// assumed to be a char buffer of at least 256 bytes.
124 template <
class TInputValue>
// NOTE(review): whether `ifs` opened successfully is not checked explicitly.
// A missing or unreadable file presumably surfaces as the header mismatch
// below.
127 std::ifstream ifs(filename);
129 ifs.getline(encoder, 256);
130 std::string encoderstr(encoder);
// Reject streams that do not start with the "pca" header.
// NOTE(review): the message says "Error opening" for a wrong header as well,
// which can mislead users about the cause.
132 if (encoderstr !=
"pca")
134 itkExceptionMacro(<<
"Error opening " << filename.c_str());
136 shark::TextInArchive ia(ifs);
// A zero dimension is replaced by the encoder's own output size.
139 if (this->m_Dimension == 0)
141 this->m_Dimension = m_Encoder.outputShape()[0];
// Copy the encoder matrix, cut it to m_Dimension rows, and reinstall it with
// the same offset.
// NOTE(review): this presumably lets a model be used with fewer outputs than
// it was saved with. It relies on matrix::resize keeping the leading rows when
// shrinking; confirm against the Shark/remora matrix documentation. Nothing
// visible guards against m_Dimension exceeding the stored output size.
144 auto eigenvectors = m_Encoder.matrix();
145 eigenvectors.resize(this->m_Dimension, m_Encoder.inputShape()[0]);
147 m_Encoder.setStructure(eigenvectors, m_Encoder.offset());
// Fragment that encodes a single sample.
// - Copies `value` element by element into a shark::RealVector.
// - Wraps it as a one-element shark::Data set and passes it through
//   m_Encoder.
// - Copies the first m_Dimension components of the encoded result into
//   `target`.
// The declaration of `target` and any return statement are not present in
// this copy.
// NOTE(review): the declarator and braces are missing, and lines carry their
// original line numbers as a prefix; this text does not compile as-is.
150 template <
class TInputValue>
154 shark::RealVector samples(value.Size());
155 for (
size_t i = 0; i < value.Size(); i++)
157 samples[i] = value[i];
// m_Encoder is applied to a shark::Data set, so the single sample is wrapped
// in a one-element container first.
160 std::vector<shark::RealVector> features;
161 features.push_back(samples);
163 shark::Data<shark::RealVector> data = shark::createDataFromRange(features);
165 data = m_Encoder(data);
167 target.SetSize(this->m_Dimension);
// data.element(0) is the encoded form of the only input sample.
169 for (
unsigned int a = 0; a < this->m_Dimension; ++a)
171 target[a] = data.element(0)[a];
// Fragment that encodes a range of samples from `input` in one pass.
// - Converts the samples selected by (startIndex, size) to Shark vectors;
//   presumably these are the `size` samples starting at `startIndex`.
// - Passes them through m_Encoder as a single shark::Data set.
// - Writes each encoded row into `targets` via SetMeasurementVector(id,
//   target), with `id` starting at startIndex.
// NOTE(review): the declarator and braces are missing, and lines carry their
// original line numbers as a prefix; this text does not compile as-is.
// NOTE(review): two pieces are not present in this copy: the body of the inner
// loop (presumably copying p[a] into target[a]) and any increment of `id`.
// Without that increment every row would overwrite index startIndex. Verify
// against the upstream source.
176 template <
class TInputValue>
180 std::vector<shark::RealVector> features;
181 Shark::ListSampleRangeToSharkVector(input, features, startIndex, size);
182 shark::Data<shark::RealVector> data = shark::createDataFromRange(features);
184 data = m_Encoder(data);
185 unsigned int id = startIndex;
// One `target` buffer is sized once and reused for every row.
186 target.SetSize(this->m_Dimension);
187 for (
const auto& p : data.elements())
189 for (
unsigned int a = 0; a < this->m_Dimension; ++a)
193 targets->SetMeasurementVector(
id, target);