22 #ifndef otbStreamingInnerProductVectorImageFilter_hxx
23 #define otbStreamingInnerProductVectorImageFilter_hxx
26 #include "itkImageRegionIterator.h"
27 #include "itkImageRegionConstIteratorWithIndex.h"
28 #include "itkNumericTraits.h"
29 #include "itkProgressReporter.h"
// Presumably the filter constructor. Its signature and opening lines are not
// visible in this excerpt.
// Creates the filter's two outputs and registers each one with
// itk::ProcessObject::SetNthOutput:
//   index 0 -> an ImageType object obtained from MakeOutput(0)
//   index 1 -> a MatrixObjectType object obtained from MakeOutput(1)
// The accessors elsewhere in this file rely on this layout when they
// static_cast GetOutput(1) to MatrixObjectType*.
35 template <
class TInputImage>
// Output 0: image output (ImageType).
46 typename ImageType::Pointer output1 =
static_cast<ImageType*
>(this->MakeOutput(0).GetPointer());
47 this->itk::ProcessObject::SetNthOutput(0, output1.GetPointer());
// Output 1: the matrix-holding data object that receives the inner-product result.
48 typename MatrixObjectType::Pointer output2 =
static_cast<MatrixObjectType*
>(this->MakeOutput(1).GetPointer());
49 this->itk::ProcessObject::SetNthOutput(1, output2.GetPointer());
// Presumably MakeOutput(idx): the factory for the filter's output data objects.
// The signature and the dispatch on the output index are not visible in this
// excerpt.
// The visible returns create, in order:
//   - a new TInputImage,
//   - a new MatrixObjectType,
//   - a new TInputImage again (likely the default/fallback case; TODO confirm).
// NOTE(review): each return hands out the raw pointer of the temporary smart
// pointer produced by New(). That is only safe if the function's return type is
// itself a smart pointer (e.g. a DataObject::Pointer), which takes a reference
// before the temporary is destroyed. Confirm against the declaration.
54 template <
class TInputImage>
60 return static_cast<itk::DataObject*
>(TInputImage::New().GetPointer());
63 return static_cast<itk::DataObject*
>(MatrixObjectType::New().GetPointer());
67 return static_cast<itk::DataObject*
>(TInputImage::New().GetPointer());
// Presumably the non-const GetInnerProductOutput() accessor (signature not
// visible). Returns output #1 downcast to MatrixObjectType*.
// The static_cast is unchecked: it relies on output 1 having been created as a
// MatrixObjectType.
72 template <
class TInputImage>
75 return static_cast<MatrixObjectType*
>(this->itk::ProcessObject::GetOutput(1));
// Presumably the const overload of GetInnerProductOutput() (signature not
// visible). Returns output #1 as a const MatrixObjectType*.
// Same unchecked downcast as the non-const accessor.
78 template <
class TInputImage>
82 return static_cast<const MatrixObjectType*
>(this->itk::ProcessObject::GetOutput(1));
// Presumably GenerateOutputInformation() (signature not visible). Propagates
// the input's meta-information to the image output.
// NOTE(review): no null check on this->GetInput() is visible in this excerpt;
// confirm one exists in the hidden lines.
85 template <
class TInputImage>
88 Superclass::GenerateOutputInformation();
// Mirror the input's information and largest possible region on the output.
91 this->GetOutput()->CopyInformation(this->GetInput());
92 this->GetOutput()->SetLargestPossibleRegion(this->GetInput()->GetLargestPossibleRegion());
// If no region has been requested yet (empty requested region), request the
// whole largest possible region.
94 if (this->GetOutput()->GetRequestedRegion().GetNumberOfPixels() == 0)
96 this->GetOutput()->SetRequestedRegion(this->GetOutput()->GetLargestPossibleRegion());
// NOTE: only the template header of this member definition is visible in this
// excerpt; its signature and body are not shown.
101 template <
class TInputImage>
// Presumably Reset() (signature not visible). Prepares the accumulation state
// before the input is processed, and publishes an initial n x n matrix on the
// inner-product output, where n is the number of components per input pixel.
111 template <
class TInputImage>
// const_cast so that the non-const UpdateOutputInformation() can be called on the input.
114 TInputImage* inputPtr =
const_cast<TInputImage*
>(this->GetInput());
// Refresh the input's information before its component count is read below.
115 inputPtr->UpdateOutputInformation();
// Ensure the output requests a non-empty region (defaults to the largest possible region).
117 if (this->GetOutput()->GetRequestedRegion().GetNumberOfPixels() == 0)
119 this->GetOutput()->SetRequestedRegion(this->GetOutput()->GetLargestPossibleRegion());
// numberOfThreads is used only in lines not visible here; presumably it sizes
// the per-thread accumulators (m_ThreadInnerProduct). TODO confirm.
122 unsigned int numberOfThreads = this->GetNumberOfThreads();
// Matrix dimension = number of components (bands) per input pixel.
123 unsigned int numberOfTrainingImages = inputPtr->GetNumberOfComponentsPerPixel();
126 tempMatrix.set_size(numberOfTrainingImages, numberOfTrainingImages);
// initMatrix's contents are set in hidden lines (presumably zero-filled). TODO confirm.
131 initMatrix.set_size(numberOfTrainingImages, numberOfTrainingImages);
133 this->GetInnerProductOutput()->Set(initMatrix);
// Presumably Synthetize() (signature not visible). Combines the per-thread
// partial sums into the final n x n inner-product matrix and stores it on the
// inner-product output.
// Steps:
//   1. sum m_ThreadInnerProduct over all threads;
//   2. fill the upper triangle from the lower triangle. The threaded pass only
//      accumulates entries with column <= row (band_y <= band_x);
//   3. divide by (n - 1), or zero the matrix when n - 1 == 0.
136 template <
class TInputImage>
140 TInputImage* inputPtr =
const_cast<TInputImage*
>(this->GetInput());
141 unsigned int numberOfTrainingImages = inputPtr->GetNumberOfComponentsPerPixel();
142 unsigned int numberOfThreads = this->GetNumberOfThreads();
// Start from an n x n zero matrix.
144 innerProduct.set_size(numberOfTrainingImages, numberOfTrainingImages);
145 innerProduct.fill(0);
// Reduce: add up every thread's partial matrix.
148 for (
unsigned int thread = 0; thread < numberOfThreads; thread++)
150 innerProduct += m_ThreadInnerProduct[thread];
// Symmetrize: copy each lower-triangle entry [y][x] into the upper-triangle slot [x][y].
// NOTE(review): numberOfTrainingImages is unsigned. If it is 0,
// (numberOfTrainingImages - 1) wraps to UINT_MAX: this outer loop then spins
// without doing work (the inner bound is 0), and the check below is true.
// Guard against zero components if that input is possible.
156 for (
unsigned int band_x = 0; band_x < (numberOfTrainingImages - 1); ++band_x)
158 for (
unsigned int band_y = band_x + 1; band_y < numberOfTrainingImages; ++band_y)
160 innerProduct[band_x][band_y] = innerProduct[band_y][band_x];
// Normalize by (number of components - 1).
// NOTE(review): the divisor is the band count minus one, not a pixel count;
// confirm this normalization is intended.
164 if ((numberOfTrainingImages - 1) != 0)
166 innerProduct /= (numberOfTrainingImages - 1);
// Presumably the else branch of the check above: with a single component the
// result is reset to zero.
170 innerProduct.fill(0);
174 this->GetInnerProductOutput()->Set(innerProduct);
// Presumably ThreadedGenerateData(outputRegionForThread, threadId). The
// signature and the declarations of inputPtr / vectorValue / mean are not
// visible in this excerpt.
// For every pixel of this thread's region, adds the products of pairs of
// components into m_ThreadInnerProduct[threadId], only for entries with
// band_y <= band_x (lower triangle including the diagonal).
// Two variants, selected by m_CenterData:
//   - m_CenterData == true: each component minus the pixel's own mean over its
//     components, computed per pixel;
//   - otherwise: the raw component values.
// Each thread writes only to the slot indexed by its threadId.
177 template <
class TInputImage>
// One progress tick per pixel of this thread's region.
185 itk::ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
186 unsigned int numberOfTrainingImages = inputPtr->GetNumberOfComponentsPerPixel();
189 itk::ImageRegionConstIterator<TInputImage> it(inputPtr, outputRegionForThread);
190 if (m_CenterData ==
true)
// Centered variant.
194 while (!it.IsAtEnd())
// Mean over the components of the current pixel (vectorValue is presumably
// it.Get(); TODO confirm).
// NOTE(review): a pixel with 0 components would divide by zero here.
198 for (
unsigned int i = 0; i < vectorValue.GetSize(); ++i)
200 mean +=
static_cast<double>(vectorValue[i]);
202 mean /=
static_cast<double>(vectorValue.GetSize());
// Accumulate the centered products, lower triangle only. The upper triangle is
// filled in later when the per-thread results are combined.
205 for (
unsigned int band_x = 0; band_x < numberOfTrainingImages; ++band_x)
207 for (
unsigned int band_y = 0; band_y <= band_x; ++band_y)
209 m_ThreadInnerProduct[threadId][band_x][band_y] +=
210 (
static_cast<double>(vectorValue[band_x]) -
mean) * (
static_cast<double>(vectorValue[band_y]) -
mean);
214 progress.CompletedPixel();
// Uncentered variant (presumably the else branch of m_CenterData == true):
// accumulate the raw component products, lower triangle only.
221 while (!it.IsAtEnd())
225 for (
unsigned int band_x = 0; band_x < numberOfTrainingImages; ++band_x)
227 for (
unsigned int band_y = 0; band_y <= band_x; ++band_y)
229 m_ThreadInnerProduct[threadId][band_x][band_y] += (
static_cast<double>(vectorValue[band_x])) * (
static_cast<double>(vectorValue[band_y]));
233 progress.CompletedPixel();
// PrintSelf: prints the superclass state, then the m_CenterData flag and the
// matrix currently stored on the inner-product output.
// The template parameter is spelled TImage here and TInputImage in the other
// definitions. This is harmless because the name is local to each definition,
// but it is inconsistent.
238 template <
class TImage>
241 Superclass::PrintSelf(os, indent);
242 os << indent <<
"m_CenterData: " << m_CenterData << std::endl;
243 os << indent <<
"InnerProduct: " << this->GetInnerProductOutput()->Get() << std::endl;