22 #ifndef otbMeanShiftSmoothingImageFilter_hxx
23 #define otbMeanShiftSmoothingImageFilter_hxx
26 #include "itkImageRegionIterator.h"
30 #include "itkProgressReporter.h"
35 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
37 : m_RangeBandwidth(16.),
38 m_RangeBandwidthRamp(0),
43 m_MaxIterationNumber(10)
46 m_NumberOfComponentsPerPixel(0)
51 m_ThreadIdNumberOfBits(0)
53 , m_BucketOptimization(false)
56 this->DynamicMultiThreadingOff();
57 this->SetNumberOfRequiredOutputs(4);
58 this->SetNthOutput(0, OutputImageType::New());
60 this->SetNthOutput(2, OutputIterationImageType::New());
65 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
70 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
77 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
84 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
88 return static_cast<const OutputImageType*
>(this->itk::ProcessObject::GetOutput(0));
91 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
95 return static_cast<OutputImageType*
>(this->itk::ProcessObject::GetOutput(0));
98 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
105 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
112 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
119 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
126 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
130 typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput();
131 typename OutputIterationImageType::Pointer iterationOutputPtr = this->GetIterationOutput();
134 spatialOutputPtr->SetBufferedRegion(spatialOutputPtr->GetRequestedRegion());
135 spatialOutputPtr->Allocate();
137 rangeOutputPtr->SetBufferedRegion(rangeOutputPtr->GetRequestedRegion());
138 rangeOutputPtr->Allocate();
140 iterationOutputPtr->SetBufferedRegion(iterationOutputPtr->GetRequestedRegion());
141 iterationOutputPtr->Allocate();
143 labelOutputPtr->SetBufferedRegion(labelOutputPtr->GetRequestedRegion());
144 labelOutputPtr->Allocate();
147 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
150 Superclass::GenerateOutputInformation();
152 m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
154 if (this->GetSpatialOutput())
156 this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(ImageDimension);
158 if (this->GetRangeOutput())
160 this->GetRangeOutput()->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel);
164 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
168 Superclass::GenerateInputRequestedRegion();
175 if (!inPtr || !outRangePtr)
182 RegionType outputRequestedRegion = outRangePtr->GetRequestedRegion();
185 RegionType inputRequestedRegion = outputRequestedRegion;
188 m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
192 for (
unsigned int comp = 0; comp < ImageDimension; ++comp)
194 margin[comp] = (m_MaxIterationNumber * m_SpatialRadius[comp]) + 1;
197 inputRequestedRegion.PadByRadius(margin);
200 if (inputRequestedRegion.Crop(inPtr->GetLargestPossibleRegion()))
202 inPtr->SetRequestedRegion(inputRequestedRegion);
211 inPtr->SetRequestedRegion(inputRequestedRegion);
214 itk::InvalidRequestedRegionError e(__FILE__, __LINE__);
215 e.SetLocation(ITK_LOCATION);
216 e.SetDescription(
"Requested region is (at least partially) outside the largest possible region.");
217 e.SetDataObject(inPtr);
222 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
230 typename InputImageType::ConstPointer inputPtr = this->GetInput();
231 typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
237 m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
239 m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
242 this->AllocateOutputs();
245 iterationOutput->FillBuffer(0);
248 spatialOutput->FillBuffer(zero);
257 typename JointImageFunctorType::Pointer jointImageFunctor = JointImageFunctorType::New();
259 jointImageFunctor->SetInput(inputPtr);
260 jointImageFunctor->GetFunctor().Initialize(ImageDimension, m_NumberOfComponentsPerPixel, m_GlobalShift);
261 jointImageFunctor->GetOutput()->SetRequestedRegion(this->GetInput()->GetBufferedRegion());
262 jointImageFunctor->Update();
263 m_JointImage = jointImageFunctor->GetOutput();
266 if (m_BucketOptimization)
272 m_JointImage->GetRequestedRegion(), m_Kernel.GetRadius(m_SpatialBandwidth), 1,
312 m_ModeTable = ModeTableImageType::New();
313 m_ModeTable->SetRegions(inputPtr->GetRequestedRegion());
314 m_ModeTable->Allocate();
315 m_ModeTable->FillBuffer(0);
328 unsigned int numThreads;
330 numThreads = this->GetNumberOfWorkUnits();
331 m_ThreadIdNumberOfBits = -1;
332 unsigned int n = numThreads;
336 m_ThreadIdNumberOfBits++;
338 if (m_ThreadIdNumberOfBits == 0)
339 m_ThreadIdNumberOfBits = 1;
340 m_NumLabels.SetSize(numThreads);
341 for (
unsigned int i = 0; i < numThreads; i++)
343 m_NumLabels[i] =
static_cast<LabelType>(i) << (
sizeof(
LabelType) * 8 - m_ThreadIdNumberOfBits);
349 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
354 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
360 assert(meanShiftVector.GetSize() == jointDimension);
361 meanShiftVector.Fill(0);
364 for (
unsigned int comp = 0; comp < ImageDimension; ++comp)
366 inputIndex[comp] = std::floor(jointPixel[comp] + 0.5) - m_GlobalShift[comp];
369 std::max(
static_cast<long int>(outputRegion.GetIndex().GetElement(comp)),
static_cast<long int>(inputIndex[comp] - m_SpatialRadius[comp] - 1));
370 const long int indexRight = std::min(
static_cast<long int>(outputRegion.GetIndex().GetElement(comp) + outputRegion.GetSize().GetElement(comp) - 1),
371 static_cast<long int>(inputIndex[comp] + m_SpatialRadius[comp] + 1));
373 regionSize[comp] = std::max(0l, indexRight -
static_cast<long int>(regionIndex[comp]) + 1);
377 neighborhoodRegion.SetIndex(regionIndex);
378 neighborhoodRegion.SetSize(regionSize);
389 while (!it.IsAtEnd())
396 for (
unsigned int comp = 0; comp < jointDimension; comp++)
398 shifts[comp] = jointNeighbor[comp] - jointPixel[comp];
399 double d = shifts[comp] / bandwidth[comp];
404 const RealType weight = m_Kernel(norm2);
440 for (
unsigned int comp = 0; comp < jointDimension; comp++)
442 meanShiftVector[comp] += weight * shifts[comp];
450 for (
unsigned int comp = 0; comp < jointDimension; comp++)
452 meanShiftVector[comp] = meanShiftVector[comp] / weightSum;
459 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
461 const RealVector& jointPixel,
462 RealVector& meanShiftVector)
464 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
466 RealType weightSum = 0;
468 for (
unsigned int comp = 0; comp < jointDimension; comp++)
470 meanShiftVector[comp] = 0;
473 RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel);
475 InputIndexType index;
476 for (
unsigned int dim = 0; dim < ImageDimension; ++dim)
478 index[dim] = jointPixel[dim] * m_SpatialBandwidth + 0.5;
481 const std::vector<unsigned int>
482 neighborBuckets = m_BucketImage.GetNeighborhoodBucketListIndices(
483 m_BucketImage.BucketIndexToBucketListIndex(
484 m_BucketImage.GetBucketIndex(
488 unsigned int numNeighbors = m_BucketImage.GetNumberOfNeighborBuckets();
489 for (
unsigned int neighborIndex = 0; neighborIndex < numNeighbors; ++neighborIndex)
491 const typename BucketImageType::BucketType & bucket = m_BucketImage.GetBucket(neighborBuckets[neighborIndex]);
492 if (bucket.empty())
continue;
493 typename BucketImageType::BucketType::const_iterator it = bucket.begin();
494 while (it != bucket.end())
496 jointNeighbor.SetData(
const_cast<RealType*
> (*it));
501 for (
unsigned int comp = 0; comp < jointDimension; comp++)
503 const RealType d = jointNeighbor[comp] - jointPixel[comp];
508 const RealType weight = m_Kernel(norm2);
514 for (
unsigned int comp = 0; comp < jointDimension; comp++)
516 meanShiftVector[comp] += weight * jointNeighbor[comp];
525 for (
unsigned int comp = 0; comp < jointDimension; comp++)
527 meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp];
533 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
542 typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
543 typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
547 typename InputImageType::ConstPointer input = this->GetInput();
550 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
551 typedef itk::ImageRegionIterator<OutputSpatialImageType> OutputSpatialIteratorType;
552 typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
553 typedef itk::ImageRegionIterator<OutputLabelImageType> OutputLabelIteratorType;
555 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
557 typename OutputImageType::PixelType rangePixel(m_NumberOfComponentsPerPixel);
563 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
564 bandwidth[comp] = m_SpatialBandwidth;
566 itk::ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
568 RegionType const& requestedRegion = input->GetRequestedRegion();
570 typedef itk::ImageRegionConstIteratorWithIndex<RealVectorImageType> JointImageIteratorType;
571 JointImageIteratorType jointIt(m_JointImage, outputRegionForThread);
573 OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
574 OutputSpatialIteratorType spatialIt(spatialOutput, outputRegionForThread);
575 OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
576 OutputLabelIteratorType labelIt(labelOutput, outputRegionForThread);
578 typedef itk::ImageRegionIterator<ModeTableImageType> ModeTableImageIteratorType;
579 ModeTableImageIteratorType modeTableIt(m_ModeTable, outputRegionForThread);
583 spatialIt.GoToBegin();
584 iterationIt.GoToBegin();
585 modeTableIt.GoToBegin();
588 unsigned int iteration = 0;
595 std::vector<InputIndexType> pointList;
597 pointList.resize(m_MaxIterationNumber);
600 unsigned int numBreaks = 0;
604 for (; !jointIt.IsAtEnd(); ++jointIt, ++rangeIt, ++spatialIt, ++iterationIt, ++modeTableIt, ++labelIt, progress.CompletedPixel())
609 if (m_ModeSearch && currentPixelMode == 1)
615 bool hasConverged =
false;
619 const RealVector& jointPixelVal = jointIt.Get();
620 for (
unsigned int comp = 0; comp < jointDimension; comp++)
621 jointPixel[comp] = jointPixelVal[comp];
623 for (
unsigned int comp = ImageDimension; comp < jointDimension; comp++)
624 bandwidth[comp] = m_RangeBandwidthRamp * jointPixel[comp] + m_RangeBandwidth;
630 unsigned int pointCount = 0;
632 while ((iteration < m_MaxIterationNumber) && (!hasConverged))
638 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
640 modeCandidate[comp] = std::floor(jointPixel[comp] - m_GlobalShift[comp] + 0.5);
648 if (modeCandidate != currentIndex && m_ModeTable->GetPixel(modeCandidate) != 2 && outputRegionForThread.IsInside(modeCandidate))
652 RealVector const& candidatePixel = m_JointImage->GetPixel(modeCandidate);
653 for (
unsigned int comp = ImageDimension; comp < jointDimension; comp++)
655 const RealType d = (candidatePixel[comp] - jointPixel[comp]) / bandwidth[comp];
663 if (m_ModeTable->GetPixel(modeCandidate) == 0)
667 pointList[pointCount++] = modeCandidate;
668 m_ModeTable->SetPixel(modeCandidate, 2);
674 rangePixel = rangeOutput->GetPixel(modeCandidate);
675 for (
unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
677 jointPixel[ImageDimension + comp] = rangePixel[comp];
691 if (m_BucketOptimization)
693 this->CalculateMeanShiftVectorBucket(jointPixel, meanShiftVector);
698 this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, bandwidth, meanShiftVector);
706 double meanShiftVectorSqNorm = 0;
707 for (
unsigned int comp = 0; comp < jointDimension; comp++)
709 const double v = meanShiftVector[comp];
710 meanShiftVectorSqNorm += v * v;
711 jointPixel[comp] += meanShiftVector[comp];
715 hasConverged = meanShiftVectorSqNorm < m_Threshold;
719 for (
unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
721 rangePixel[comp] = jointPixel[ImageDimension + comp];
724 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
726 spatialPixel[comp] = jointPixel[comp] - currentIndex[comp] - m_GlobalShift[comp];
729 rangeIt.Set(rangePixel);
730 spatialIt.Set(spatialPixel);
732 const typename OutputIterationImageType::PixelType iterationPixel = iteration;
733 iterationIt.Set(iterationPixel);
742 if (hasConverged || iteration == m_MaxIterationNumber)
744 m_NumLabels[threadId]++;
745 label = m_NumLabels[threadId];
749 label = labelOutput->GetPixel(modeCandidate);
754 for (
unsigned int i = 0; i < pointCount; i++)
756 rangeOutput->SetPixel(pointList[i], rangePixel);
757 m_ModeTable->SetPixel(pointList[i], 1);
758 labelOutput->SetPixel(pointList[i], label);
764 labelIt.Set(labelZero);
771 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
775 typedef itk::ImageRegionIterator<OutputLabelImageType> OutputLabelIteratorType;
776 OutputLabelIteratorType labelIt(labelOutput, labelOutput->GetRequestedRegion());
784 itk::VariableLengthVector<LabelType> newLabelOffset;
785 newLabelOffset.SetSize(this->GetNumberOfWorkUnits());
786 newLabelOffset[0] = 0;
787 for (itk::ThreadIdType i = 1; i < this->GetNumberOfWorkUnits(); i++)
792 m_NumLabels[i - 1] & ((
static_cast<LabelType>(1) << (
sizeof(
LabelType) * 8 - m_ThreadIdNumberOfBits)) -
static_cast<LabelType>(1));
793 newLabelOffset[i] = localNumLabel + newLabelOffset[i - 1];
798 while (!labelIt.IsAtEnd())
803 const itk::ThreadIdType threadId = label >> (
sizeof(
LabelType) * 8 - m_ThreadIdNumberOfBits);
809 newLabel += newLabelOffset[threadId];
811 labelIt.Set(newLabel);
817 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
820 Superclass::PrintSelf(os, indent);
821 os << indent <<
"Spatial bandwidth: " << m_SpatialBandwidth << std::endl;
822 os << indent <<
"Range bandwidth: " << m_RangeBandwidth << std::endl;
Creation of an "otb" image which contains metadata.
itk::SmartPointer< Self > Pointer
Superclass::InternalPixelType InternalPixelType
void AllocateOutputs() override
OutputImageType::Pointer OutputImagePointerType
itk::VariableLengthVector< RealType > RealVector
InputImageType::Pointer InputImagePointerType
~MeanShiftSmoothingImageFilter() override
void ThreadedGenerateData(const OutputRegionType &outputRegionForThread, itk::ThreadIdType threadId) override
const OutputSpatialImageType * GetSpatialOutput() const
const OutputLabelImageType * GetLabelOutput() const
InputIndexType m_GlobalShift
const OutputImageType * GetRangeOutput() const
void GenerateOutputInformation(void) override
InputImageType::RegionType RegionType
TOutputImage OutputImageType
const OutputIterationImageType * GetIterationOutput() const
MeanShiftSmoothingImageFilter()
virtual void CalculateMeanShiftVector(const typename RealVectorImageType::Pointer inputImagePtr, const RealVector &jointPixel, const OutputRegionType &outputRegion, const RealVector &bandwidth, RealVector &meanShiftVector)
void AfterThreadedGenerateData() override
OutputSpatialImageType::PixelType OutputSpatialPixelType
void GenerateInputRequestedRegion() override
InputImageType::SizeType InputSizeType
InputImageType::IndexType InputIndexType
OutputSpatialImageType::Pointer OutputSpatialImagePointerType
TOutputIterationImage OutputIterationImageType
void BeforeThreadedGenerateData() override
OutputImageType::RegionType OutputRegionType
void PrintSelf(std::ostream &os, itk::Indent indent) const override
const InternalPixelType * GetPixelPointer() const
Implements neighborhood-wise generic operation on image.
Creation of an "otb" vector image which contains metadata.
Superclass::PixelType PixelType
itk::SmartPointer< const Self > ConstPointer
itk::SmartPointer< Self > Pointer
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.