21 #ifndef otbFastICAInternalOptimizerVectorImageFilter_hxx
22 #define otbFastICAInternalOptimizerVectorImageFilter_hxx
26 #include <itkImageRegionIterator.h>
28 #include <vnl/vnl_math.h>
29 #include <vnl/vnl_matrix.h>
// Default set-up of the filter's members.
// NOTE(review): the signature and braces of this definition are not present in
// this excerpt; the statements look like the constructor body — confirm
// against the full file.
34 template <
class TInputImage,
class TOutputImage>
// The filter is declared to need two inputs (read elsewhere in this file as
// GetInput(0) and GetInput(1)).
37 this->SetNumberOfRequiredInputs(2);
// Band index used by the per-pixel loop starts at 0.
39 m_CurrentBandForLoop = 0;
// Default non-linearity: g(x) = tanh(x).
43 m_NonLinearity = [](
double x) {
return std::tanh(x); };
// Default derivative of the non-linearity: g'(x) = 1 - tanh(x)^2.
// NOTE(review): std::pow(t, 2.) could be written t * t, which avoids a general
// pow call for every sample; left unchanged here.
44 m_NonLinearityDerivative = [](
double x) {
return 1 - std::pow(std::tanh(x), 2.); };
// Instantiate the internal transform filter held by this object.
46 m_TransformFilter = TransformFilterType::New();
// Output meta-data set-up: after the superclass has filled in the output
// information, the output is given the same number of components per pixel
// as input 0.
// NOTE(review): the signature and braces are not present in this excerpt; the
// name is inferred from the Superclass call below.
49 template <
class TInputImage,
class TOutputImage>
52 Superclass::GenerateOutputInformation();
// Output pixels carry one component per band of the first input.
54 this->GetOutput()->SetNumberOfComponentsPerPixel(this->GetInput(0)->GetNumberOfComponentsPerPixel());
// Prepares the per-thread accumulators before the threaded pass.
// NOTE(review): the signature, braces and the condition guarding the throw are
// not present in this excerpt; presumably this is the BeforeThreadedGenerateData
// hook and the exception is raised when no initial W matrix was supplied —
// confirm against the full file.
57 template <
class TInputImage,
class TOutputImage>
62 throw itk::ExceptionObject(__FILE__, __LINE__,
"Give the initial W matrix", ITK_LOCATION);
// One accumulator slot per thread; the slots are indexed by threadId in the
// threaded pass.
65 m_BetaVector.resize(this->GetNumberOfThreads());
66 m_DenVector.resize(this->GetNumberOfThreads());
67 m_NbSamples.resize(this->GetNumberOfThreads());
// resize() keeps the values of already-existing elements, so every slot is
// explicitly reset to zero.
69 std::fill(m_BetaVector.begin(), m_BetaVector.end(), 0.);
70 std::fill(m_DenVector.begin(), m_DenVector.end(), 0.);
71 std::fill(m_NbSamples.begin(), m_NbSamples.end(), 0.);
// Per-thread processing of one output region.
// For every pixel, the value x of the current band of input 1 is passed through
// the non-linearity g; the input-0 pixel scaled by g(x) is built as the output
// vector z, and per-thread statistics are accumulated into the slot threadId.
// NOTE(review): this excerpt is incomplete — the start of the signature, the
// braces, the declarations and updates of inputRegion, beta, den and nbSample,
// the write of z to the output and the iterator increments are not visible.
// Presumably beta sums x * g(x), den sums g'(x) and nbSample counts pixels —
// confirm against the full file.
74 template <
class TInputImage,
class TOutputImage>
76 itk::ThreadIdType threadId)
// Derive the input region that corresponds to this thread's output region.
79 this->CallCopyOutputRegionToInputRegion(inputRegion, outputRegionForThread);
// Read-only iterators over both inputs and a writable iterator over the output.
81 itk::ImageRegionConstIterator<InputImageType> input0It(this->GetInput(0), inputRegion);
82 itk::ImageRegionConstIterator<InputImageType> input1It(this->GetInput(1), inputRegion);
83 itk::ImageRegionIterator<OutputImageType> outputIt(this->GetOutput(), outputRegionForThread);
// Number of components of input 0; also the length of each output pixel.
85 unsigned int nbBands = this->GetInput(0)->GetNumberOfComponentsPerPixel();
// Stop as soon as any of the three iterators reaches its end.
94 while (!input0It.IsAtEnd() && !input1It.IsAtEnd() && !outputIt.IsAtEnd())
// x: value of the currently selected band in the input-1 pixel.
96 double x =
static_cast<double>(input1It.Get()[GetCurrentBandForLoop()]);
// g(x)
97 double g_x = m_NonLinearity(x);
// x * g(x)
99 double x_g_x = x * g_x;
// g'(x)
102 double gp = m_NonLinearityDerivative(x);
// z = g(x) * (input-0 pixel), computed band by band.
// NOTE(review): input0It.Get() is evaluated once per band; fetching the pixel
// once before the loop would avoid the repeated call — left unchanged here.
107 typename OutputImageType::PixelType z(nbBands);
108 for (
unsigned int bd = 0; bd < nbBands; bd++)
109 z[bd] = g_x * input0It.Get()[bd];
// Fold this region's partial results into the slot owned by this thread.
117 m_BetaVector[threadId] += beta;
118 m_DenVector[threadId] += den;
119 m_NbSamples[threadId] += nbSample;
// Reduction step: sums the per-thread accumulators and derives m_Beta and
// m_Den from the totals.
// NOTE(review): the signature, braces and the declarations/initial values of
// beta, den and nbSample are not present in this excerpt; presumably this is
// the AfterThreadedGenerateData hook — confirm against the full file.
122 template <
class TInputImage,
class TOutputImage>
// Add up the contribution of every thread slot.
129 for (itk::ThreadIdType i = 0; i < this->GetNumberOfThreads(); ++i)
131 beta += m_BetaVector[i];
132 den += m_DenVector[i];
133 nbSample += m_NbSamples[i];
// m_Beta is the summed beta divided by the summed sample count.
// NOTE(review): no guard against nbSample == 0 is visible here; if no samples
// were accumulated these divisions would not produce finite values — confirm
// whether a check exists outside this excerpt.
136 m_Beta = beta / nbSample;
// m_Den is the summed den divided by the sample count, minus m_Beta.
137 m_Den = den / nbSample - m_Beta;