21 #ifndef otbListSampleToBalancedListSampleFilter_hxx
22 #define otbListSampleToBalancedListSampleFilter_hxx
25 #include "itkProgressReporter.h"
26 #include "itkHistogram.h"
27 #include "itkNumericTraits.h"
/**
 * Constructor (presumably of ListSampleToBalancedListSampleFilter, per the
 * include guard). NOTE(review): the signature and braces are not visible in
 * this excerpt.
 *
 * Declares that the filter requires two inputs and two outputs, allocates
 * output #1 through MakeOutput(1), creates the Gaussian additive noise
 * filter instance, and sets the default balancing factor to 5.
 */
35 template <
class TInputSampleList,
class TLabelSampleList,
class TOutputSampleList>
// Two inputs (presumably samples + labels -- confirm against the setters)
// and two outputs.
38 this->SetNumberOfRequiredInputs(2);
39 this->SetNumberOfRequiredOutputs(2);
// Output #1 is allocated explicitly here; how output #0 is set up is not
// visible in this excerpt.
43 this->itk::ProcessObject::SetNthOutput(1, this->MakeOutput(1).GetPointer());
// Gaussian additive noise filter instance owned by this filter.
45 m_AddGaussianNoiseFilter = GaussianAdditiveNoiseType::New();
// Default balancing factor.
46 m_BalancingFactor = 5;
/**
 * MakeOutput (presumably): allocates the data object for a given output index.
 * NOTE(review): the signature and the index dispatch (switch/if) are not
 * visible in this excerpt, so which index maps to which branch below cannot be
 * confirmed here. The visible branches are:
 *   - delegation to Superclass::MakeOutput(0);
 *   - creation of a LabelSampleListType instance;
 *   - creation of an InputSampleListType instance.
 */
49 template <
class TInputSampleList,
class TLabelSampleList,
class TOutputSampleList>
57 Superclass::MakeOutput(0);
// Branch producing a label sample list. The constructor requests output 1
// via MakeOutput(1), so this is presumably that branch -- confirm.
61 output =
static_cast<itk::DataObject*
>(LabelSampleListType::New().GetPointer());
// Branch producing a sample list of the input type.
65 output =
static_cast<itk::DataObject*
>(InputSampleListType::New().GetPointer());
// NOTE(review): only the template header of this member is visible in this
// excerpt; its signature and body are missing, so its role cannot be stated.
72 template <
class TInputSampleList,
class TLabelSampleList,
class TOutputSampleList>
// NOTE(review): partial member. Only its template header and an input-count
// guard are visible. The guard distinguishes the case where fewer than two
// inputs are connected; what the member does in that case is not shown here.
80 template <
class TInputSampleList,
class TLabelSampleList,
class TOutputSampleList>
84 if (this->GetNumberOfInputs() < 2)
// NOTE(review): only the template header of this member is visible in this
// excerpt; its signature and body are missing, so its role cannot be stated.
93 template <
class TInputSampleList,
class TLabelSampleList,
class TOutputSampleList>
/**
 * Presumably ComputeMaxSampleFrequency(), which is invoked at the start of
 * GenerateData. It counts samples per label with a 1-D histogram, then
 * appends one coefficient per histogram bin to m_MultiplicativeCoefficient:
 *   coeff = 0                                             if the bin is empty
 *   coeff = (m_BalancingFactor * maxCount) / binCount     otherwise (truncated)
 *
 * NOTE(review): lines are missing from this excerpt (signature, braces,
 * measurement extraction, iterator increments). The comments below describe
 * only the visible statements.
 */
102 template <
class TInputSampleList,
class TLabelSampleList,
class TOutputSampleList>
// First pass over the label list: find the largest label value.
// NOTE(review): for floating-point T, NumericTraits<T>::min() is the smallest
// *positive* value (NonpositiveMin() is the lowest). The label is later used
// as a bin count and index, so LabelValueType is presumably integral -- confirm.
106 LabelValueType maxLabel = itk::NumericTraits<LabelValueType>::min();
109 typename LabelSampleListType::ConstPointer labelPtr = this->GetInputLabel();
110 typename LabelSampleListType::ConstIterator labelIt = labelPtr->Begin();
112 while (labelIt != labelPtr->End())
117 if (currentInputMeasurement[0] > maxLabel)
118 maxLabel = currentInputMeasurement[0];
// One-dimensional histogram with maxLabel + 1 bins, used as a per-label counter.
124 typedef typename itk::Statistics::Histogram<unsigned int> HistogramType;
125 typename HistogramType::Pointer histogram = HistogramType::New();
126 typename HistogramType::SizeType size(1);
127 size.Fill(maxLabel + 1);
128 histogram->SetMeasurementVectorSize(1);
129 histogram->Initialize(size);
// Second pass: count occurrences of each label. The label value is passed
// directly as the bin identifier.
// NOTE(review): a negative label would not map to a valid bin -- confirm that
// labels are guaranteed non-negative.
131 labelIt = labelPtr->Begin();
132 while (labelIt != labelPtr->End())
136 histogram->IncreaseFrequency(currentInputMeasurement[0], 1);
// Largest per-label count across all bins.
141 unsigned int maxvalue = 0;
142 HistogramType::Iterator iter = histogram->Begin();
144 while (iter != histogram->End())
146 if (
static_cast<unsigned int>(iter.GetFrequency()) > maxvalue)
147 maxvalue =
static_cast<unsigned int>(iter.GetFrequency());
// Target count that each label is scaled toward.
// NOTE(review): unsigned multiplication; wraps if
// m_BalancingFactor * maxvalue exceeds the range of unsigned int.
153 unsigned int balancedFrequency = m_BalancingFactor * maxvalue;
// One coefficient per bin, in bin order. An empty bin (frequency ~ 0) gets 0;
// otherwise the truncated ratio balancedFrequency / frequency is stored.
// NOTE(review): values are appended with push_back and no reset of
// m_MultiplicativeCoefficient is visible here -- confirm it is cleared before
// repeated calls (e.g. re-running Update()).
160 iter = histogram->Begin();
161 while (iter != histogram->End())
163 if (iter.GetFrequency() - 1e-10 < 0.)
164 m_MultiplicativeCoefficient.push_back(0);
167 unsigned int coeff =
static_cast<unsigned int>(balancedFrequency / iter.GetFrequency());
168 m_MultiplicativeCoefficient.push_back(coeff);
/**
 * Presumably GenerateData(). For every (sample, label) pair it:
 *   1. emits the original sample, with its label;
 *   2. runs the noise filter for m_MultiplicativeCoefficient[label]
 *      iterations on a one-sample list;
 *   3. emits every sample the filter produced, each tagged with the same label.
 *
 * NOTE(review): lines are missing from this excerpt (signature, input/output
 * pointer retrieval, local declarations, iterator increments). The comments
 * below describe only the visible statements.
 */
175 template <
class TInputSampleList,
class TLabelSampleList,
class TOutputSampleList>
// Fills m_MultiplicativeCoefficient before any sample is processed.
179 this->ComputeMaxSampleFrequency();
// Reset the sample output. NOTE(review): a matching Clear() of outputLabel is
// not visible in this excerpt -- confirm it exists.
188 outputSampleListPtr->Clear();
191 outputSampleListPtr->SetMeasurementVectorSize(inputSampleListPtr->GetMeasurementVectorSize());
192 outputLabel->SetMeasurementVectorSize(labelSampleListPtr->GetMeasurementVectorSize());
194 typename InputSampleListType::ConstIterator inputIt = inputSampleListPtr->Begin();
195 typename LabelSampleListType::ConstIterator labelIt = labelSampleListPtr->Begin();
// Progress is reported once per input sample.
198 itk::ProgressReporter progress(
this, 0, inputSampleListPtr->Size());
// Walk samples and labels in lockstep; stops when either list ends.
204 while (inputIt != inputSampleListPtr->End() && labelIt != labelSampleListPtr->End())
// Single-sample list that is fed to the noise filter.
215 tempListSample->SetMeasurementVectorSize(inputSampleListPtr->GetMeasurementVectorSize());
216 tempListSample->PushBack(currentInputMeasurement);
// Number of noisy copies requested for this label.
// NOTE(review): the label is used directly as an index with no visible bounds
// check. The coefficients were built from the label list, one per bin in
// [0, maxLabel], so this presumably holds only for non-negative labels -- confirm.
219 unsigned int iterations = m_MultiplicativeCoefficient[currentLabelMeasurement[0]];
// Run the noise filter (its declaration is not visible here) on the single
// sample for `iterations` iterations.
223 noisingFilter->SetInput(tempListSample);
224 noisingFilter->SetNumberOfIteration(iterations);
225 noisingFilter->Update();
// Copy the original sample, converting each component to OutputValueType,
// then emit it together with its label.
229 currentOutputMeasurement.SetSize(currentInputMeasurement.GetSize());
232 for (
unsigned int idx = 0; idx < inputSampleListPtr->GetMeasurementVectorSize(); ++idx)
233 currentOutputMeasurement[idx] =
static_cast<OutputValueType>(currentInputMeasurement[idx]);
236 outputSampleListPtr->PushBack(currentOutputMeasurement);
239 outputLabel->PushBack(currentLabelMeasurement);
// Append every sample produced by the noise filter, each with the same label.
242 typename OutputSampleListType::ConstIterator tempIt = noisingFilter->GetOutput()->Begin();
244 while (tempIt != noisingFilter->GetOutput()->End())
249 outputSampleListPtr->PushBack(currentTempMeasurement);
252 outputLabel->PushBack(currentLabelMeasurement);
258 progress.CompletedPixel();
/**
 * PrintSelf: delegates to the superclass printer.
 * NOTE(review): any filter-specific members printed in the omitted lines are
 * not visible in this excerpt.
 */
265 template <
class TInputSampleList,
class TLabelSampleList,
class TOutputSampleList>
269 Superclass::PrintSelf(os, indent);