21 #ifndef otbFastICAImageFilter_hxx
22 #define otbFastICAImageFilter_hxx
27 #include "itkNumericTraits.h"
28 #include "itkProgressReporter.h"
30 #include <vnl/vnl_matrix.h>
31 #include <vnl/algo/vnl_matrix_inverse.h>
32 #include <vnl/algo/vnl_generalized_eigensystem.h>
// Fragment of a template member definition; its signature line and braces are
// not present in this view.
// NOTE(review): the member initializations below suggest this is the
// FastICAImageFilter constructor — confirm against the full file.
37 template <
class TInputImage,
class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
// The filter declares one required input.
40 this->SetNumberOfRequiredInputs(1);
// 0 acts as "not set": GenerateOutputInformation replaces a value of 0 (or a
// value larger than the input's component count) with the input's number of
// components per pixel.
42 m_NumberOfPrincipalComponentsRequired = 0;
// Start with no user-supplied transformation matrix; the direction flag
// defaults to "forward".
44 m_GivenTransformationMatrix =
false;
45 m_IsTransformationForward =
true;
// Stopping criteria used by the iterative matrix estimation: an iteration cap
// and a convergence threshold.
47 m_NumberOfIterations = 50;
48 m_ConvergenceThreshold = 1E-4;
// Default non-linearity is tanh, paired with its derivative 1 - tanh(x)^2.
50 m_NonLinearity = [](
double x) {
return std::tanh(x); };
51 m_NonLinearityDerivative = [](
double x) {
return 1 - std::pow(std::tanh(x), 2.); };
// Internal PCA stage: normalization enabled, variance-based normalization
// disabled.
55 m_PCAFilter = PCAFilterType::New();
56 m_PCAFilter->SetUseNormalization(
true);
57 m_PCAFilter->SetUseVarianceForNormalization(
false);
// Internal filter that later receives the transformation matrix via SetMatrix.
59 m_TransformFilter = TransformFilterType::New();
// Fragment of a template member definition; its signature line, braces and
// case labels are not present in this view.
// NOTE(review): the call to Superclass::GenerateOutputInformation() suggests
// this is GenerateOutputInformation — confirm against the full file.
// Sets the number of output components according to the transformation
// direction, then delegates to the direction-specific setup routine.
// Throws itk::ExceptionObject when the output size cannot be determined or the
// direction is unsupported.
62 template <
class TInputImage,
class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
66 Superclass::GenerateOutputInformation();
68 switch (
static_cast<int>(DirectionOfTransformation))
// Clamp the requested component count: 0 or a value above the input's
// component count falls back to the input's number of components per pixel.
72 if (m_NumberOfPrincipalComponentsRequired == 0 || m_NumberOfPrincipalComponentsRequired > this->GetInput()->GetNumberOfComponentsPerPixel())
74 m_NumberOfPrincipalComponentsRequired = this->GetInput()->GetNumberOfComponentsPerPixel();
76 m_PCAFilter->SetNumberOfPrincipalComponentsRequired(m_NumberOfPrincipalComponentsRequired);
77 this->GetOutput()->SetNumberOfComponentsPerPixel(m_NumberOfPrincipalComponentsRequired);
// Output size is taken as the larger dimension of the PCA filter's
// transformation matrix.
// NOTE(review): the guard tests m_GivenTransformationMatrix but the size is
// read from m_PCAFilter's matrix — confirm this is intended.
82 unsigned int theOutputDimension = 0;
83 if (m_GivenTransformationMatrix)
85 const auto & pcaMatrix = m_PCAFilter->GetTransformationMatrix();
86 theOutputDimension = pcaMatrix.Rows() >= pcaMatrix.Cols() ? pcaMatrix.Rows() : pcaMatrix.Cols();
90 throw itk::ExceptionObject(__FILE__, __LINE__,
"Mixture matrix is required to know the output size", ITK_LOCATION);
93 this->GetOutput()->SetNumberOfComponentsPerPixel(theOutputDimension);
// Message spelling fixed ("templated") so it matches the identical diagnostic
// raised by the data-generation dispatch.
98 throw itk::ExceptionObject(__FILE__, __LINE__,
"Class should be templated with FORWARD or INVERSE only...", ITK_LOCATION);
// Second dispatch: run the direction-specific pipeline setup.
101 switch (
static_cast<int>(DirectionOfTransformation))
105 ForwardGenerateOutputInformation();
110 ReverseGenerateOutputInformation();
// Fragment of a template member definition; its signature line and braces are
// not present in this view.
// NOTE(review): presumably ForwardGenerateOutputInformation — confirm against
// the full file.
// Connects input -> m_PCAFilter -> m_TransformFilter and makes sure a
// non-empty transformation matrix in forward form is available.
// Throws itk::ExceptionObject when the transformation matrix is empty.
116 template <
class TInputImage,
class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
// const_cast strips const from the pointer returned by GetInput() so it can be
// stored in a non-const smart pointer and handed to SetInput.
119 typename InputImageType::Pointer inputImgPtr =
const_cast<InputImageType*
>(this->GetInput());
121 m_PCAFilter->SetInput(inputImgPtr);
122 m_PCAFilter->GetOutput()->UpdateOutputInformation();
// No matrix was supplied: estimate one.
124 if (!m_GivenTransformationMatrix)
126 GenerateTransformationMatrix();
// A matrix was supplied but is flagged as inverse: replace it with its
// SVD-based pseudo-inverse and flag it as forward.
128 else if (!m_IsTransformationForward)
131 m_IsTransformationForward =
true;
132 vnl_svd<MatrixElementType> invertor(m_TransformationMatrix.GetVnlMatrix());
133 m_TransformationMatrix = invertor.pinverse();
// Refuse to continue with an empty matrix.
136 if (m_TransformationMatrix.GetVnlMatrix().empty())
138 throw itk::ExceptionObject(__FILE__, __LINE__,
"Empty transformation matrix", ITK_LOCATION);
// The transform stage consumes the PCA output.
141 m_TransformFilter->SetInput(m_PCAFilter->GetOutput());
142 m_TransformFilter->SetMatrix(m_TransformationMatrix.GetVnlMatrix());
// Fragment of a template member definition; its signature line and braces are
// not present in this view.
// NOTE(review): presumably ReverseGenerateOutputInformation — confirm against
// the full file.
// Connects input -> m_TransformFilter -> m_PCAFilter using a user-supplied
// matrix converted to inverse form.
// Throws itk::ExceptionObject when no matrix was given or the matrix is empty.
145 template <
class TInputImage,
class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
// This direction cannot estimate a matrix itself; one must have been supplied.
148 if (!m_GivenTransformationMatrix)
150 throw itk::ExceptionObject(__FILE__, __LINE__,
"No Transformation matrix given", ITK_LOCATION);
153 if (m_TransformationMatrix.GetVnlMatrix().empty())
155 throw itk::ExceptionObject(__FILE__, __LINE__,
"Empty transformation matrix", ITK_LOCATION);
// The matrix is flagged as forward: replace it with its SVD-based
// pseudo-inverse and flag it as inverse.
158 if (m_IsTransformationForward)
161 m_IsTransformationForward =
false;
162 vnl_svd<MatrixElementType> invertor(m_TransformationMatrix.GetVnlMatrix());
163 m_TransformationMatrix = invertor.pinverse();
// The transform stage runs first, directly on this filter's input.
166 m_TransformFilter->SetInput(this->GetInput());
167 m_TransformFilter->SetMatrix(m_TransformationMatrix.GetVnlMatrix());
// The PCA stage consumes the transform output.
// NOTE(review): original lines between 167 and 174 are not in view; additional
// PCA configuration may happen there.
174 m_PCAFilter->SetInput(m_TransformFilter->GetOutput());
// Fragment of a template member definition; its signature line, braces and
// case labels are not present in this view.
// NOTE(review): presumably GenerateData — confirm against the full file.
// Dispatches on DirectionOfTransformation to ForwardGenerateData or
// ReverseGenerateData; the remaining path throws itk::ExceptionObject.
178 template <
class TInputImage,
class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
181 switch (
static_cast<int>(DirectionOfTransformation))
184 return ForwardGenerateData();
186 return ReverseGenerateData();
// Reached when neither of the two returns above was taken.
188 throw itk::ExceptionObject(__FILE__, __LINE__,
"Class should be templated with FORWARD or INVERSE only...", ITK_LOCATION);
// Fragment of a template member definition; its signature line and braces are
// not present in this view.
// NOTE(review): presumably ForwardGenerateData (m_TransformFilter is the last
// stage wired by the forward setup) — confirm against the full file.
// Runs the internal pipeline through m_TransformFilter and exposes its result
// as this filter's output.
192 template <
class TInputImage,
class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
// Graft this filter's output onto the internal filter before updating it.
195 m_TransformFilter->GraftOutput(this->GetOutput());
196 m_TransformFilter->Update();
// Graft the internal filter's updated output back onto this filter.
198 this->GraftOutput(m_TransformFilter->GetOutput());
// Fragment of a template member definition; its signature line and braces are
// not present in this view.
// NOTE(review): presumably ReverseGenerateData (m_PCAFilter is the last stage
// wired by the reverse setup) — confirm against the full file.
// Runs the internal pipeline through m_PCAFilter and exposes its result as
// this filter's output.
201 template <
class TInputImage,
class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
// Graft this filter's output onto the internal filter before updating it.
204 m_PCAFilter->GraftOutput(this->GetOutput());
205 m_PCAFilter->Update();
// Graft the internal filter's updated output back onto this filter.
206 this->GraftOutput(m_PCAFilter->GetOutput());
// Fragment of a template member definition; its signature line, braces and
// several declarations (W, W_old, W_tmp, valP, transf, norm, transformer,
// optimizer, estimator) are not present in this view.
// NOTE(review): presumably GenerateTransformationMatrix — confirm against the
// full file.
// Iteratively updates the matrix W from the PCA output until the iteration cap
// or the convergence threshold is reached, then stores W in
// m_TransformationMatrix.
209 template <
class TInputImage,
class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
// The progress reporter is given the iteration count as both of its last two
// arguments.
212 itk::ProgressReporter reporter(
this, 0, GetNumberOfIterations(), GetNumberOfIterations());
// Start from the largest double so the first loop test on convergence passes.
214 double convergence = itk::NumericTraits<double>::max();
215 unsigned int iteration = 0;
217 const unsigned int size = this->GetNumberOfPrincipalComponentsRequired();
// Post-increment in the condition: inside the loop body `iteration` is already
// 1-based, and it is incremented once more by the test that ends the loop.
222 while (iteration++ < GetNumberOfIterations() && convergence > GetConvergenceThreshold())
// Second optimizer input defaults to the PCA output.
226 typename InputImageType::Pointer img =
const_cast<InputImageType*
>(m_PCAFilter->GetOutput());
// When W is not the identity, the PCA output is run through `transformer`
// with W as its matrix.
// NOTE(review): how `img` is rebound afterwards is not in view.
228 if (!W.is_identity())
230 transformer->SetInput(GetPCAFilter()->GetOutput());
231 transformer->SetMatrix(W);
232 transformer->Update();
// Per-band update of W.
236 for (
unsigned int band = 0; band < size; band++)
238 otbMsgDebugMacro(<<
"Iteration " << iteration <<
", bande " << band <<
", convergence " << convergence);
// The optimizer receives the PCA output (input 0) and `img` (input 1), plus
// the non-linearity pair and the band being processed.
241 optimizer->SetInput(0, m_PCAFilter->GetOutput());
242 optimizer->SetInput(1, img);
244 optimizer->SetNonLinearity(this->GetNonLinearity(), this->GetNonLinearityDerivative());
245 optimizer->SetCurrentBandForLoop(band);
// The estimator consumes the optimizer output; its mean is read below.
248 estimator->SetInput(optimizer->GetOutput());
254 optimizer->Synthetize();
// Update column `band` of W from the estimator mean and the optimizer's
// beta/den values (step scaled by m_Mu), accumulating the squared column norm.
257 for (
unsigned int bd = 0; bd < size; bd++)
259 W(bd, band) -= m_Mu * (estimator->GetMean()[bd] - optimizer->GetBeta() * W(bd, band)) / optimizer->GetDen();
260 norm += std::pow(W(bd, band), 2.);
// Normalize column `band` by the square root of the accumulated norm.
262 for (
unsigned int bd = 0; bd < size; bd++)
263 W(bd, band) /= std::sqrt(norm);
// SVD-based step on W_tmp: each diagonal entry of valP becomes 1/sqrt(entry),
// then W_tmp is rebuilt as transf * valP * transf^T.
// NOTE(review): the definitions of W_tmp, valP and transf are not in view;
// presumably they derive from W and `solver` — confirm.
268 vnl_svd<MatrixElementType> solver(W_tmp);
270 for (
unsigned int i = 0; i < valP.rows(); ++i)
271 valP(i, i) = 1. / std::sqrt(
static_cast<double>(valP(i, i)));
273 W_tmp = transf * valP * transf.transpose();
// Convergence measure: sum of absolute element-wise differences between W and
// W_old.
// NOTE(review): the assignment of W_old and any reset of `convergence` before
// this accumulation are not in view.
278 for (
unsigned int i = 0; i < W.rows(); ++i)
279 for (
unsigned int j = 0; j < W.cols(); ++j)
280 convergence += std::abs(W(i, j) - W_old(i, j));
282 reporter.CompletedPixel();
// Publish the estimated matrix.
285 this->m_TransformationMatrix = W;
287 otbMsgDebugMacro(<<
"Final convergence " << convergence <<
" after " << iteration <<
" iterations");