22 #ifndef otbMeanShiftSmoothingImageFilter_hxx
23 #define otbMeanShiftSmoothingImageFilter_hxx
26 #include "itkImageRegionIterator.h"
30 #include "itkProgressReporter.h"
// Constructor: initializes the filter parameters and declares the four
// outputs (range, spatial, iteration and label images).
// Defaults visible here: range bandwidth 16, range bandwidth ramp 0,
// at most 10 mean-shift iterations per pixel, bucket optimization disabled.
35 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
37 : m_RangeBandwidth(16.),
38 m_RangeBandwidthRamp(0),
43 m_MaxIterationNumber(10)
46 m_NumberOfComponentsPerPixel(0)
51 m_ThreadIdNumberOfBits(0)
53 , m_BucketOptimization(false)
// Four outputs are required. Index 0 holds the range image (OutputImageType)
// and index 2 the per-pixel iteration count (OutputIterationImageType).
56 this->SetNumberOfRequiredOutputs(4);
57 this->SetNthOutput(0, OutputImageType::New());
59 this->SetNthOutput(2, OutputIterationImageType::New());
// Output accessors. The filter's outputs are stored by index:
//   0 -> range image      (OutputImageType)
//   1 -> spatial image    (OutputSpatialImageType)
//   2 -> iteration image  (OutputIterationImageType)
//   3 -> label image      (OutputLabelImageType)
// Each accessor static_casts the generic itk::ProcessObject output. The cast
// is unchecked, so the stored output objects must actually be of these types.
64 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
69 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Spatial output: index 1.
76 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
80 return static_cast<OutputSpatialImageType*
>(this->itk::ProcessObject::GetOutput(1));
// Range output (const and non-const access): index 0.
83 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
87 return static_cast<const OutputImageType*
>(this->itk::ProcessObject::GetOutput(0));
90 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
94 return static_cast<OutputImageType*
>(this->itk::ProcessObject::GetOutput(0));
97 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Iteration output: index 2.
104 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
108 return static_cast<OutputIterationImageType*
>(this->itk::ProcessObject::GetOutput(2));
111 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Label output: index 3.
118 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
122 return static_cast<OutputLabelImageType*
>(this->itk::ProcessObject::GetOutput(3));
// Allocates the pixel buffer of each of the four outputs (spatial, range,
// iteration, label), sizing each buffer to that output's requested region.
// Buffers are not initialized here. The spatial and iteration images are
// zero-filled afterwards in BeforeThreadedGenerateData.
125 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
129 typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput();
130 typename OutputIterationImageType::Pointer iterationOutputPtr = this->GetIterationOutput();
133 spatialOutputPtr->SetBufferedRegion(spatialOutputPtr->GetRequestedRegion());
134 spatialOutputPtr->Allocate();
136 rangeOutputPtr->SetBufferedRegion(rangeOutputPtr->GetRequestedRegion());
137 rangeOutputPtr->Allocate();
139 iterationOutputPtr->SetBufferedRegion(iterationOutputPtr->GetRequestedRegion());
140 iterationOutputPtr->Allocate();
142 labelOutputPtr->SetBufferedRegion(labelOutputPtr->GetRequestedRegion());
143 labelOutputPtr->Allocate();
// Propagates output information from the superclass, then sets the number of
// components per pixel of the vector outputs:
//  - spatial output: one component per image dimension, holding the
//    displacement of the final position from the starting pixel;
//  - range output: as many components as the input image.
146 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
149 Superclass::GenerateOutputInformation();
151 m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
153 if (this->GetSpatialOutput())
155 this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(ImageDimension);
157 if (this->GetRangeOutput())
159 this->GetRangeOutput()->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel);
// Enlarges the input requested region around the range output's requested
// region, so that mean-shift trajectories starting in the output region can
// read the input pixels they may reach.
163 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
167 Superclass::GenerateInputRequestedRegion();
174 if (!inPtr || !outRangePtr)
181 RegionType outputRequestedRegion = outRangePtr->GetRequestedRegion();
184 RegionType inputRequestedRegion = outputRequestedRegion;
// Per-dimension margin = max iterations * spatial kernel radius + 1.
// NOTE(review): this assumes a trajectory moves by at most one kernel radius
// per iteration. Confirm for the kernel in use.
187 m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
191 for (
unsigned int comp = 0; comp < ImageDimension; ++comp)
193 margin[comp] = (m_MaxIterationNumber * m_SpatialRadius[comp]) + 1;
196 inputRequestedRegion.PadByRadius(margin);
// Clip the padded region to the input's largest possible region. On success,
// the clipped region becomes the input requested region.
199 if (inputRequestedRegion.Crop(inPtr->GetLargestPossibleRegion()))
201 inPtr->SetRequestedRegion(inputRequestedRegion);
// Cropping failed: record the attempted region on the input and describe the
// failure in an InvalidRequestedRegionError.
210 inPtr->SetRequestedRegion(inputRequestedRegion);
213 itk::InvalidRequestedRegionError e(__FILE__, __LINE__);
214 e.SetLocation(ITK_LOCATION);
215 e.SetDescription(
"Requested region is (at least partially) outside the largest possible region.");
216 e.SetDataObject(inPtr);
// Prepares the state shared by all threads before the threaded pass:
//  - spatial kernel radius and number of input components;
//  - output allocation, with the iteration and spatial images zero-filled;
//  - the joint (spatial + range) image m_JointImage, computed over the
//    input's buffered region;
//  - optionally (m_BucketOptimization) a bucket structure built from the
//    joint image;
//  - the mode table m_ModeTable: one entry per input requested pixel,
//    initialized to 0. ThreadedGenerateData uses 0 = not yet visited,
//    2 = on the trajectory currently being followed, 1 = mode assigned;
//  - per-thread label counters m_NumLabels.
221 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
229 typename InputImageType::ConstPointer inputPtr = this->GetInput();
230 typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
236 m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
238 m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
241 this->AllocateOutputs();
244 iterationOutput->FillBuffer(0);
247 spatialOutput->FillBuffer(zero);
// Joint image: the functor is given m_GlobalShift. The spatial part is
// presumably (pixel index + m_GlobalShift), given how indices are recovered
// in CalculateMeanShiftVector. Confirm against the functor.
256 typename JointImageFunctorType::Pointer jointImageFunctor = JointImageFunctorType::New();
258 jointImageFunctor->SetInput(inputPtr);
259 jointImageFunctor->GetFunctor().Initialize(ImageDimension, m_NumberOfComponentsPerPixel, m_GlobalShift);
260 jointImageFunctor->GetOutput()->SetRequestedRegion(this->GetInput()->GetBufferedRegion());
261 jointImageFunctor->Update();
262 m_JointImage = jointImageFunctor->GetOutput();
265 if (m_BucketOptimization)
271 m_JointImage->GetRequestedRegion(), m_Kernel.GetRadius(m_SpatialBandwidth), 1,
311 m_ModeTable = ModeTableImageType::New();
312 m_ModeTable->SetRegions(inputPtr->GetRequestedRegion());
313 m_ModeTable->Allocate();
314 m_ModeTable->FillBuffer(0);
// Label encoding: the top m_ThreadIdNumberOfBits bits of a label hold the id
// of the thread that produced it. Each thread can then allocate labels from
// its own counter m_NumLabels[threadId]. AfterThreadedGenerateData compacts
// them afterwards.
// m_ThreadIdNumberOfBits is derived from numThreads by an incrementing loop
// starting at -1, and is forced to at least 1.
// NOTE(review): starting at -1 relies on unsigned wraparound if the member is
// unsigned. Confirm the member's type.
327 unsigned int numThreads;
329 numThreads = this->GetNumberOfThreads();
330 m_ThreadIdNumberOfBits = -1;
331 unsigned int n = numThreads;
335 m_ThreadIdNumberOfBits++;
337 if (m_ThreadIdNumberOfBits == 0)
338 m_ThreadIdNumberOfBits = 1;
339 m_NumLabels.SetSize(numThreads);
340 for (
unsigned int i = 0; i < numThreads; i++)
342 m_NumLabels[i] =
static_cast<LabelType>(i) << (
sizeof(
LabelType) * 8 - m_ThreadIdNumberOfBits);
// Computes the mean-shift vector at jointPixel by scanning the spatial
// neighbourhood of its rounded position. The result is the kernel-weighted
// mean of (neighbour - jointPixel) over all joint components. Each
// component's distance is normalized by bandwidth[comp].
// meanShiftVector must already have jointDimension components (asserted).
348 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
353 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
359 assert(meanShiftVector.GetSize() == jointDimension);
360 meanShiftVector.Fill(0);
// Neighbourhood = rounded spatial position +/- (spatial radius + 1), clipped
// to outputRegion. An empty intersection yields a size of 0.
// NOTE(review): here m_GlobalShift is subtracted after rounding, whereas
// ThreadedGenerateData subtracts it before rounding. The two agree only if
// m_GlobalShift is integral.
363 for (
unsigned int comp = 0; comp < ImageDimension; ++comp)
365 inputIndex[comp] = std::floor(jointPixel[comp] + 0.5) - m_GlobalShift[comp];
368 std::max(
static_cast<long int>(outputRegion.GetIndex().GetElement(comp)),
static_cast<long int>(inputIndex[comp] - m_SpatialRadius[comp] - 1));
369 const long int indexRight = std::min(
static_cast<long int>(outputRegion.GetIndex().GetElement(comp) + outputRegion.GetSize().GetElement(comp) - 1),
370 static_cast<long int>(inputIndex[comp] + m_SpatialRadius[comp] + 1));
372 regionSize[comp] = std::max(0l, indexRight -
static_cast<long int>(regionIndex[comp]) + 1);
376 neighborhoodRegion.SetIndex(regionIndex);
377 neighborhoodRegion.SetSize(regionSize);
// For each neighbour: offset from jointPixel, bandwidth-normalized squared
// distance, kernel weight, weighted accumulation of the offset.
388 while (!it.IsAtEnd())
395 for (
unsigned int comp = 0; comp < jointDimension; comp++)
397 shifts[comp] = jointNeighbor[comp] - jointPixel[comp];
398 double d = shifts[comp] / bandwidth[comp];
403 const RealType weight = m_Kernel(norm2);
439 for (
unsigned int comp = 0; comp < jointDimension; comp++)
441 meanShiftVector[comp] += weight * shifts[comp];
// Normalize by the total kernel weight.
449 for (
unsigned int comp = 0; comp < jointDimension; comp++)
451 meanShiftVector[comp] = meanShiftVector[comp] / weightSum;
// Bucket-accelerated variant of CalculateMeanShiftVector. Only joint pixels
// stored in the buckets neighbouring jointPixel's bucket are visited.
// It accumulates the kernel-weighted mean of neighbour positions, then
// subtracts jointPixel, which yields a shift vector written into
// meanShiftVector.
458 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
460 const RealVector& jointPixel,
461 RealVector& meanShiftVector)
463 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
465 RealType weightSum = 0;
467 for (
unsigned int comp = 0; comp < jointDimension; comp++)
469 meanShiftVector[comp] = 0;
// jointNeighbor is re-pointed at each bucket element's data with SetData
// below (presumably without copying). The const_cast is needed only because
// SetData takes a non-const pointer; the data is only read here.
472 RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel);
// Recover a pixel index from jointPixel's spatial part.
// NOTE(review): multiplying by m_SpatialBandwidth presumably undoes a
// normalization of the joint image done in bucket mode. Confirm. The "+ 0.5"
// followed by implicit conversion truncates, so it rounds correctly only for
// non-negative coordinates.
474 InputIndexType index;
475 for (
unsigned int dim = 0; dim < ImageDimension; ++dim)
477 index[dim] = jointPixel[dim] * m_SpatialBandwidth + 0.5;
// List of bucket indices neighbouring the bucket that contains `index`.
480 const std::vector<unsigned int>
481 neighborBuckets = m_BucketImage.GetNeighborhoodBucketListIndices(
482 m_BucketImage.BucketIndexToBucketListIndex(
483 m_BucketImage.GetBucketIndex(
487 unsigned int numNeighbors = m_BucketImage.GetNumberOfNeighborBuckets();
488 for (
unsigned int neighborIndex = 0; neighborIndex < numNeighbors; ++neighborIndex)
490 const typename BucketImageType::BucketType & bucket = m_BucketImage.GetBucket(neighborBuckets[neighborIndex]);
491 if (bucket.empty())
continue;
492 typename BucketImageType::BucketType::const_iterator it = bucket.begin();
493 while (it != bucket.end())
// View the current bucket element as a joint-space vector.
495 jointNeighbor.SetData(
const_cast<RealType*
> (*it));
500 for (
unsigned int comp = 0; comp < jointDimension; comp++)
502 const RealType d = jointNeighbor[comp] - jointPixel[comp];
507 const RealType weight = m_Kernel(norm2);
513 for (
unsigned int comp = 0; comp < jointDimension; comp++)
515 meanShiftVector[comp] += weight * jointNeighbor[comp];
// Weighted mean position minus the current position = mean-shift vector.
524 for (
unsigned int comp = 0; comp < jointDimension; comp++)
526 meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp];
// Runs mean-shift for every pixel of outputRegionForThread.
// Starting from the pixel's joint value, it repeatedly adds the mean-shift
// vector until the squared shift norm drops below m_Threshold or
// m_MaxIterationNumber iterations are reached. It then writes, per pixel:
// the final range value, the spatial displacement from the start pixel, the
// iteration count and a label.
// With m_ModeSearch, pixels whose mode table entry is already 1 are skipped.
// A trajectory that reaches a pixel of this thread's region that is already
// on a completed path takes that pixel's range value. All pixels recorded
// along the trajectory then receive the same range value, mode status and
// label.
532 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
541 typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
542 typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
546 typename InputImageType::ConstPointer input = this->GetInput();
549 typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
550 typedef itk::ImageRegionIterator<OutputSpatialImageType> OutputSpatialIteratorType;
551 typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
552 typedef itk::ImageRegionIterator<OutputLabelImageType> OutputLabelIteratorType;
554 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
556 typename OutputImageType::PixelType rangePixel(m_NumberOfComponentsPerPixel);
// Spatial bandwidth is constant. Range bandwidths are set per pixel below.
562 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
563 bandwidth[comp] = m_SpatialBandwidth;
565 itk::ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
567 RegionType const& requestedRegion = input->GetRequestedRegion();
569 typedef itk::ImageRegionConstIteratorWithIndex<RealVectorImageType> JointImageIteratorType;
570 JointImageIteratorType jointIt(m_JointImage, outputRegionForThread);
572 OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
573 OutputSpatialIteratorType spatialIt(spatialOutput, outputRegionForThread);
574 OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
575 OutputLabelIteratorType labelIt(labelOutput, outputRegionForThread);
577 typedef itk::ImageRegionIterator<ModeTableImageType> ModeTableImageIteratorType;
578 ModeTableImageIteratorType modeTableIt(m_ModeTable, outputRegionForThread);
582 spatialIt.GoToBegin();
583 iterationIt.GoToBegin();
584 modeTableIt.GoToBegin();
587 unsigned int iteration = 0;
// Pixels visited along the current trajectory. At most one is recorded per
// iteration, so m_MaxIterationNumber entries suffice.
594 std::vector<InputIndexType> pointList;
596 pointList.resize(m_MaxIterationNumber);
599 unsigned int numBreaks = 0;
603 for (; !jointIt.IsAtEnd(); ++jointIt, ++rangeIt, ++spatialIt, ++iterationIt, ++modeTableIt, ++labelIt, progress.CompletedPixel())
608 if (m_ModeSearch && currentPixelMode == 1)
614 bool hasConverged =
false;
618 const RealVector& jointPixelVal = jointIt.Get();
619 for (
unsigned int comp = 0; comp < jointDimension; comp++)
620 jointPixel[comp] = jointPixelVal[comp];
// Range bandwidth per component: m_RangeBandwidthRamp * value + m_RangeBandwidth.
622 for (
unsigned int comp = ImageDimension; comp < jointDimension; comp++)
623 bandwidth[comp] = m_RangeBandwidthRamp * jointPixel[comp] + m_RangeBandwidth;
629 unsigned int pointCount = 0;
631 while ((iteration < m_MaxIterationNumber) && (!hasConverged))
637 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
639 modeCandidate[comp] = std::floor(jointPixel[comp] - m_GlobalShift[comp] + 0.5);
// NOTE(review): m_ModeTable->GetPixel(modeCandidate) is evaluated before
// outputRegionForThread.IsInside(modeCandidate). If modeCandidate can leave
// m_ModeTable's region (the input requested region), this reads out of
// bounds. Even inside it, the read may touch another thread's region while
// that thread writes it. Consider testing IsInside first.
647 if (modeCandidate != currentIndex && m_ModeTable->GetPixel(modeCandidate) != 2 && outputRegionForThread.IsInside(modeCandidate))
651 RealVector const& candidatePixel = m_JointImage->GetPixel(modeCandidate);
652 for (
unsigned int comp = ImageDimension; comp < jointDimension; comp++)
654 const RealType d = (candidatePixel[comp] - jointPixel[comp]) / bandwidth[comp];
662 if (m_ModeTable->GetPixel(modeCandidate) == 0)
666 pointList[pointCount++] = modeCandidate;
667 m_ModeTable->SetPixel(modeCandidate, 2);
673 rangePixel = rangeOutput->GetPixel(modeCandidate);
674 for (
unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
676 jointPixel[ImageDimension + comp] = rangePixel[comp];
690 if (m_BucketOptimization)
692 this->CalculateMeanShiftVectorBucket(jointPixel, meanShiftVector);
697 this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, bandwidth, meanShiftVector);
// Apply the shift and test convergence on its squared norm.
705 double meanShiftVectorSqNorm = 0;
706 for (
unsigned int comp = 0; comp < jointDimension; comp++)
708 const double v = meanShiftVector[comp];
709 meanShiftVectorSqNorm += v * v;
710 jointPixel[comp] += meanShiftVector[comp];
714 hasConverged = meanShiftVectorSqNorm < m_Threshold;
718 for (
unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
720 rangePixel[comp] = jointPixel[ImageDimension + comp];
// Spatial output = final position minus the start pixel's joint position.
723 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
725 spatialPixel[comp] = jointPixel[comp] - currentIndex[comp] - m_GlobalShift[comp];
728 rangeIt.Set(rangePixel);
729 spatialIt.Set(spatialPixel);
731 const typename OutputIterationImageType::PixelType iterationPixel = iteration;
732 iterationIt.Set(iterationPixel);
// New mode: allocate the next label from this thread's counter (thread id in
// the high bits). Otherwise reuse the label of the pixel at modeCandidate.
741 if (hasConverged || iteration == m_MaxIterationNumber)
743 m_NumLabels[threadId]++;
744 label = m_NumLabels[threadId];
748 label = labelOutput->GetPixel(modeCandidate);
// Give every recorded trajectory pixel the same mode value and label, and
// mark it as done (1).
753 for (
unsigned int i = 0; i < pointCount; i++)
755 rangeOutput->SetPixel(pointList[i], rangePixel);
756 m_ModeTable->SetPixel(pointList[i], 1);
757 labelOutput->SetPixel(pointList[i], label);
763 labelIt.Set(labelZero);
// Compacts the per-thread labels into a single numbering.
// Each thread's label count is the low bits of its m_NumLabels entry (the
// thread-id bits are masked off). These counts are accumulated into
// newLabelOffset. Every label is then offset by the value for the thread
// encoded in its top m_ThreadIdNumberOfBits bits (presumably after masking
// off those bits into newLabel).
770 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
774 typedef itk::ImageRegionIterator<OutputLabelImageType> OutputLabelIteratorType;
775 OutputLabelIteratorType labelIt(labelOutput, labelOutput->GetRequestedRegion());
// Prefix sum of per-thread label counts: newLabelOffset[i] = labels produced
// by threads 0..i-1.
783 itk::VariableLengthVector<LabelType> newLabelOffset;
784 newLabelOffset.SetSize(this->GetNumberOfThreads());
785 newLabelOffset[0] = 0;
786 for (itk::ThreadIdType i = 1; i < this->GetNumberOfThreads(); i++)
791 m_NumLabels[i - 1] & ((
static_cast<LabelType>(1) << (
sizeof(
LabelType) * 8 - m_ThreadIdNumberOfBits)) -
static_cast<LabelType>(1));
792 newLabelOffset[i] = localNumLabel + newLabelOffset[i - 1];
// Rewrite every label: recover the producing thread from the high bits, then
// add that thread's offset.
797 while (!labelIt.IsAtEnd())
802 const itk::ThreadIdType threadId = label >> (
sizeof(
LabelType) * 8 - m_ThreadIdNumberOfBits);
808 newLabel += newLabelOffset[threadId];
810 labelIt.Set(newLabel);
// Prints the superclass state followed by this filter's parameters,
// starting with the spatial and range bandwidths.
816 template <
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
819 Superclass::PrintSelf(os, indent);
820 os << indent <<
"Spatial bandwidth: " << m_SpatialBandwidth << std::endl;
821 os << indent <<
"Range bandwidth: " << m_RangeBandwidth << std::endl;