OTB  10.0.0
Orfeo Toolbox
otbImageDimensionalityReductionFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbImageDimensionalityReductionFilter_hxx
21 #define otbImageDimensionalityReductionFilter_hxx
22 
24 #include "itkImageRegionIterator.h"
25 #include "itkProgressReporter.h"
26 
27 namespace otb
28 {
// Constructor (its signature line, 33, is not captured in this listing).
// Sets up the process-object layout:
//   input 0  : image to reduce (the only required input)
//   input 1  : optional mask (see SetInputMask / GetInputMask)
//   output 0 : reduced image (TOutputImage)
//   output 1 : confidence image (ConfidenceImageType)
32 template <class TInputImage, class TOutputImage, class TMaskImage>
34 {
// Disable ITK dynamic multithreading so that the classic
// ThreadedGenerateData(region, threadId) overload below is the one invoked.
35  this->DynamicMultiThreadingOff();
36  this->SetNumberOfIndexedInputs(2);
37  this->SetNumberOfRequiredInputs(1);
39 
40  this->SetNumberOfRequiredOutputs(2);
41  this->SetNthOutput(0, TOutputImage::New());
42  this->SetNthOutput(1, ConfidenceImageType::New());
43  m_UseConfidenceMap = false;
// Batch mode (one PredictBatch call per region) is the default; see ThreadedGenerateData.
44  m_BatchMode = true;
45 }
46 
// Registers `mask` as indexed input 1 of the filter.
// (Signature line 48 is not captured in this listing.)
47 template <class TInputImage, class TOutputImage, class TMaskImage>
49 {
// itk::ProcessObject::SetNthInput takes a non-const DataObject pointer, hence the
// const_cast; the mask is only handed back as const (see GetInputMask).
50  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType*>(mask));
51 }
52 
// Returns the mask stored at input index 1 as a const MaskImageType pointer,
// or nullptr when the filter has fewer than two inputs.
// (Signature lines 54-55 are not captured in this listing.)
53 template <class TInputImage, class TOutputImage, class TMaskImage>
56 {
57  if (this->GetNumberOfInputs() < 2)
58  {
59  return nullptr;
60  }
61  return static_cast<const MaskImageType*>(this->itk::ProcessObject::GetInput(1));
62 }
63 
// Returns output index 1 (the confidence image created in the constructor),
// or nullptr when the filter has fewer than two outputs.
// (Signature lines 65-66 are not captured in this listing.)
64 template <class TInputImage, class TOutputImage, class TMaskImage>
67 {
68  if (this->GetNumberOfOutputs() < 2)
69  {
70  return nullptr;
71  }
72  return static_cast<ConfidenceImageType*>(this->itk::ProcessObject::GetOutput(1));
73 }
74 
// Pre-processing hook run before the threaded pass (its signature line, 76, is not
// captured here; presumably BeforeThreadedGenerateData -- confirm against the header).
// In batch mode, and only in builds compiled with OpenMP, ITK is limited to a single
// work unit so that the model's own OpenMP parallelism is not nested inside ITK threads.
75 template <class TInputImage, class TOutputImage, class TMaskImage>
77 {
78  if (m_BatchMode)
79  {
80 #ifdef _OPENMP
81  // OpenMP will take care of threading
82  this->SetNumberOfWorkUnits(1);
83 #endif
84  }
85 }
86 
// Per-pixel (non-batch) processing of one output region: each input pixel is passed
// to m_Model->Predict and the result is written to the same location in output 0.
// Progress is reported once per pixel.
// NOTE(review): inputMaskPtr and confidencePtr are fetched but never used, so the mask
// is ignored and the confidence output is never written here.
// (The first signature line, 88, is not captured in this listing.)
87 template <class TInputImage, class TOutputImage, class TMaskImage>
89  itk::ThreadIdType threadId)
90 {
91  // Get the input pointers
92  InputImageConstPointerType inputPtr = this->GetInput();
93  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
94  OutputImagePointerType outputPtr = this->GetOutput();
95  ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
96 
97  // Progress reporting
98  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
99 
100  // Define iterators
101  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
102  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
103 
104  InputIteratorType inIt(inputPtr, outputRegionForThread);
105  OutputIteratorType outIt(outputPtr, outputRegionForThread);
106 
107  // Walk the part of the image
108  for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
109  {
110  // Classify (predict the reduced vector for this pixel)
111  outIt.Set(m_Model->Predict(inIt.Get()));
112  progress.CompletedPixel();
113  }
114 }
115 
// Propagates output information via the superclass, then sets the number of
// components per output pixel to the model's output dimension.
// Throws (itkGenericExceptionMacro) when no model has been set.
// (Signature line 117 is not captured in this listing.)
116 template <class TInputImage, class TOutputImage, class TMaskImage>
118 {
119  Superclass::GenerateOutputInformation();
120  if (!m_Model)
121  {
122  itkGenericExceptionMacro(<< "No model for dimensionality reduction");
123  }
124  this->GetOutput()->SetNumberOfComponentsPerPixel(m_Model->GetDimension());
125 }
126 
// Batch processing of one output region, in three steps:
//   1. copy every input pixel of the region into a ListSample (num_features components each);
//   2. run a single m_Model->PredictBatch call on that sample list;
//   3. write the returned vectors back to output 0 in the same iteration order.
// Progress is reported once per written pixel.
// NOTE(review): inputMaskPtr and confidencePtr are fetched but never used, so the mask
// is ignored and the confidence output is never written here.
// (The first signature line, 128, is not captured in this listing.)
127 template <class TInputImage, class TOutputImage, class TMaskImage>
129  itk::ThreadIdType threadId)
130 {
131  // Get the input pointers
132  InputImageConstPointerType inputPtr = this->GetInput();
133  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
134  OutputImagePointerType outputPtr = this->GetOutput();
135  ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
136 
137  // Progress reporting
138  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
139 
140  // Define iterators
141  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
142  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
143 
144  InputIteratorType inIt(inputPtr, outputRegionForThread);
145  OutputIteratorType outIt(outputPtr, outputRegionForThread);
146 
147  typedef typename ModelType::InputSampleType InputSampleType;
148  typedef typename ModelType::InputListSampleType InputListSampleType;
149  typedef typename ModelType::TargetValueType TargetValueType;
150  typedef typename ModelType::TargetListSampleType TargetListSampleType;
151 
152  typename InputListSampleType::Pointer samples = InputListSampleType::New();
153  unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
154  samples->SetMeasurementVectorSize(num_features);
// One reusable buffer; PushBack stores a copy of it for each pixel.
155  InputSampleType sample(num_features);
156 
157  // Fill the samples
158  for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
159  {
160  typename InputImageType::PixelType pix = inIt.Get();
161  for (size_t feat = 0; feat < num_features; ++feat)
162  {
163  sample[feat] = pix[feat];
164  }
165  samples->PushBack(sample);
166  }
167  // Make the batch prediction
168  typename TargetListSampleType::Pointer labels;
169 
170  // This call is threadsafe
171  labels = m_Model->PredictBatch(samples);
172 
173  // Set the output values
// NOTE(review): labIt is never compared with labels->End(); this assumes PredictBatch
// returns exactly one vector per input sample, in input order -- confirm for each model.
174  typename TargetListSampleType::ConstIterator labIt = labels->Begin();
175  for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
176  {
177  itk::VariableLengthVector<TargetValueType> labelValue;
178  labelValue = labIt.GetMeasurementVector();
179  ++labIt;
180  outIt.Set(labelValue);
181  progress.CompletedPixel();
182  }
183 }
184 
// Threaded entry point. Dispatches on m_BatchMode: BatchThreadedGenerateData
// (one PredictBatch call per region) or ClassicThreadedGenerateData (one Predict
// call per pixel). (The first signature line, 186, is not captured in this listing.)
185 template <class TInputImage, class TOutputImage, class TMaskImage>
187  itk::ThreadIdType threadId)
188 {
189  if (m_BatchMode)
190  {
191  this->BatchThreadedGenerateData(outputRegionForThread, threadId);
192  }
193  else
194  {
195  this->ClassicThreadedGenerateData(outputRegionForThread, threadId);
196  }
197 }
198 
// Prints the object state by delegating to the superclass only.
// NOTE(review): the filter's own members (m_BatchMode, m_UseConfidenceMap, m_Model)
// are not printed. (Signature line 203 is not captured in this listing.)
202 template <class TInputImage, class TOutputImage, class TMaskImage>
204 {
205  Superclass::PrintSelf(os, indent);
206 }
207 
208 } // End namespace otb
209 #endif
void PrintSelf(std::ostream &os, itk::Indent indent) const override
void BatchThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
void ClassicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId) override
Creation of an "otb" image which contains metadata.
Definition: otbImage.h:92
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
itk::Statistics::ListSample< InputSampleType > InputListSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.