17 #ifndef __itkMRFImageFilter_txx
18 #define __itkMRFImageFilter_txx
// Constructor. Sets default parameter values (only part of the member-initializer
// list is visible in this listing). Throws ExceptionObject when the input and
// classified image dimensions differ, then installs the default MRF
// neighborhood weights.
24 template<
class TInputImage,
class TClassifiedImage>
28 m_MaximumNumberOfIterations(50),
// 27 == 3*3*3, the radius-1 neighborhood size in 3-D; SetDefaultMRFNeighborhoodWeight
// recomputes m_NeighborhoodSize from InputImageDimension.
30 m_NeighborhoodSize(27),
31 m_TotalNumberOfValidPixelsInOutputImage(1),
32 m_TotalNumberOfPixelsInInputImage(1),
33 m_ErrorTolerance(0.2),
35 m_ClassProbability(0),
36 m_NumberOfIterations(0),
37 m_StopCondition(MaximumNumberOfIterations),
// Input and classified (output) images must have the same dimension.
41 if( (
int)InputImageDimension != (
int)ClassifiedImageDimension )
44 msg <<
"Input image dimension: " << InputImageDimension <<
" != output image dimension: " << ClassifiedImageDimension;
45 throw ExceptionObject(__FILE__, __LINE__,msg.str().c_str(),ITK_LOCATION);
// Start from a zero radius and empty weight/influence vectors.
47 m_InputImageNeighborhoodRadius.Fill(0);
48 m_MRFNeighborhoodWeight.resize(0);
49 m_NeighborInfluence.resize(0);
50 m_DummyVector.resize(0);
// Passing an empty vector makes SetMRFNeighborhoodWeight fall back to the defaults.
// NOTE(review): the explicit SetDefaultMRFNeighborhoodWeight() call below therefore
// looks redundant -- confirm against the full SetMRFNeighborhoodWeight body.
51 this->SetMRFNeighborhoodWeight( m_DummyVector );
52 this->SetDefaultMRFNeighborhoodWeight();
// Template header of a member definition whose body is not visible in this listing
// (original lines 55-61). NOTE(review): presumably the destructor -- confirm.
55 template<
class TInputImage,
class TClassifiedImage>
62 template<
class TInputImage,
class TClassifiedImage>
67 Superclass::PrintSelf(os,indent);
70 os << indent <<
" MRF Image filter object " << std::endl;
72 os << indent <<
" Number of classes: " << m_NumberOfClasses << std::endl;
74 os << indent <<
" Maximum number of iterations: " <<
75 m_MaximumNumberOfIterations << std::endl;
77 os << indent <<
" Error tolerance for convergence: " <<
78 m_ErrorTolerance << std::endl;
80 os << indent <<
" Size of the MRF neighborhood radius:" <<
81 m_InputImageNeighborhoodRadius << std::endl;
83 os << indent <<
" Number of elements in MRF neighborhood :" <<
84 static_cast<unsigned long>( m_MRFNeighborhoodWeight.size() ) << std::endl;
86 os << indent <<
" Neighborhood weight : [";
87 const unsigned int neighborhoodWeightSize =
88 static_cast<unsigned int>( m_MRFNeighborhoodWeight.size() );
89 for (i=0; i+1 < neighborhoodWeightSize; i++)
91 os << m_MRFNeighborhoodWeight[i] <<
", ";
93 os << m_MRFNeighborhoodWeight[i] <<
"]" << std::endl;
95 os << indent <<
" Smoothing factor for the MRF neighborhood:" <<
96 m_SmoothingFactor << std::endl;
98 os << indent <<
"StopCondition: "
99 << m_StopCondition << std::endl;
101 os << indent <<
" Number of iterations: " <<
102 m_NumberOfIterations << std::endl;
// Propagates the output's requested region to the input. When both the input and
// output pointers (obtained on lines not visible in this listing) are non-null, the
// input's requested region is set to exactly the output's requested region.
109 template <
class TInputImage,
class TClassifiedImage>
120 if (inputPtr && outputPtr)
122 inputPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
130 template <
class TInputImage,
class TClassifiedImage>
139 TClassifiedImage *imgData;
140 imgData =
dynamic_cast<TClassifiedImage*
>( output );
141 imgData->SetRequestedRegionToLargestPossibleRegion();
// Gives the output the same largest possible region as the input.
148 template <
class TInputImage,
class TClassifiedImage>
153 typename TInputImage::ConstPointer input = this->GetInput();
154 typename TClassifiedImage::Pointer output = this->GetOutput();
155 output->SetLargestPossibleRegion( input->GetLargestPossibleRegion() );
// Main pipeline entry (partially visible in this listing):
//  1. feeds the input image to the classifier and updates it;
//  2. runs the iterative MRF relabelling (ApplyMRFImageFilter);
//  3. allocates the output over its requested region and writes a label into each
//     output pixel while walking the classifier's classified image over that region.
// NOTE(review): how `labelvalue` is derived from labelledImageIt is on lines not
// visible here.
159 template<
class TInputImage,
class TClassifiedImage>
172 m_ClassifierPtr->SetInputImage( inputImage );
175 m_ClassifierPtr->Update();
180 this->ApplyMRFImageFilter();
// The output buffer covers exactly the requested region.
185 outputPtr->SetBufferedRegion( outputPtr->GetRequestedRegion() );
186 outputPtr->Allocate();
194 labelledImageIt( m_ClassifierPtr->GetClassifiedImage(),
195 outputPtr->GetRequestedRegion() );
201 outImageIt( outputPtr, outputPtr->GetRequestedRegion() );
205 while ( !outImageIt.IsAtEnd() )
210 outImageIt.Set( labelvalue );
217 template<
class TInputImage,
class TClassifiedImage>
222 if( ( ptrToClassifier.
IsNull() ) || (m_NumberOfClasses <= 0) )
224 throw ExceptionObject(__FILE__, __LINE__,
"NumberOfClasses is <= 0",ITK_LOCATION);
226 m_ClassifierPtr = ptrToClassifier;
227 m_ClassifierPtr->SetNumberOfClasses( m_NumberOfClasses );
// Scalar overload: uses the same radius (radiusValue) in every dimension by filling
// a local radius object (declared on a line not visible here) and forwarding it to
// the SetNeighborhoodRadius overload that takes a radius object.
234 template<
class TInputImage,
class TClassifiedImage>
241 for(
unsigned int i=0;i < InputImageDimension; ++i)
243 radius[i] = radiusValue;
245 this->SetNeighborhoodRadius( radius );
// Array overload: copies one radius per dimension from radiusArray into a local
// radius object and forwards it to the SetNeighborhoodRadius overload that takes a
// radius object.
// NOTE(review): radiusArray must provide at least InputImageDimension entries.
249 template<
class TInputImage,
class TClassifiedImage>
255 for(
unsigned int i=0;i < InputImageDimension; ++i)
257 radius[i] = radiusArray[i];
260 this->SetNeighborhoodRadius( radius );
// Sets, per dimension, the radii used by the input, labelled and label-status image
// neighborhood iterators. The right-hand sides of these assignments are on lines
// not visible in this listing.
265 template<
class TInputImage,
class TClassifiedImage>
271 for(
unsigned int i=0;i < InputImageDimension; ++i)
273 m_InputImageNeighborhoodRadius[ i ] =
275 m_LabelledImageNeighborhoodRadius[ i ] =
277 m_LabelStatusImageNeighborhoodRadius[ i ] =
// Installs the built-in MRF neighborhood weights for a radius-1 neighborhood.
// m_NeighborhoodSize becomes 3^InputImageDimension, and m_MRFNeighborhoodWeight is
// resized to that size and filled with:
//  - 3-D: elements 0-8 and 18-26 get 1.3*m_SmoothingFactor, elements 9-17 get
//         1.7*m_SmoothingFactor; elements 4 and 22 are then set to
//         1.5*m_SmoothingFactor, and the middle element (13 == 27/2) to 0;
//  - 2-D: all elements get 1.7*m_SmoothingFactor, and the middle element (4 == 9/2) 0;
//  - >3-D: all elements get 1.
// NOTE(review): InputImageDimension == 1 matches none of the branches, so the weight
// vector is left as-is -- confirm this is intended.
288 template<
class TInputImage,
class TClassifiedImage>
304 m_NeighborhoodSize = 1;
305 int neighborhoodRadius = 1;
306 for(
unsigned int i = 0; i < InputImageDimension; i++ )
308 m_NeighborhoodSize *= (2*neighborhoodRadius+1);
310 if( (InputImageDimension == 3) )
313 m_MRFNeighborhoodWeight.resize( m_NeighborhoodSize );
315 for(
int i = 0; i < 9; i++ )
316 m_MRFNeighborhoodWeight[i] = 1.3 * m_SmoothingFactor;
318 for(
int i = 9; i < 18; i++ )
319 m_MRFNeighborhoodWeight[i] = 1.7 * m_SmoothingFactor;
321 for(
int i = 18; i < 27; i++ )
322 m_MRFNeighborhoodWeight[i] = 1.3 * m_SmoothingFactor;
325 m_MRFNeighborhoodWeight[4] = 1.5 * m_SmoothingFactor;
326 m_MRFNeighborhoodWeight[13] = 0.0;
327 m_MRFNeighborhoodWeight[22] = 1.5 * m_SmoothingFactor;
329 if( (InputImageDimension == 2) )
332 m_MRFNeighborhoodWeight.resize( m_NeighborhoodSize );
334 for(
int i = 0; i < m_NeighborhoodSize; i++ )
335 m_MRFNeighborhoodWeight[i] = 1.7 * m_SmoothingFactor;
338 m_MRFNeighborhoodWeight[4] = 0;
340 if( (InputImageDimension > 3) )
// Like the 2-D and 3-D branches, size the vector before writing; without this the
// loop below writes past the end of a possibly empty vector (the constructor
// resizes it to 0).
m_MRFNeighborhoodWeight.resize( m_NeighborhoodSize );
342 for(
int i = 0; i < m_NeighborhoodSize; i++ )
344 m_MRFNeighborhoodWeight[i] = 1;
// Sets user-supplied MRF neighborhood weights.
// An empty betaMatrix selects the built-in defaults (SetDefaultMRFNeighborhoodWeight).
// Otherwise the neighborhood size implied by m_InputImageNeighborhoodRadius (the
// product of 2*r+1 over all dimensions) must equal betaMatrix.size(), or an
// ExceptionObject is thrown. Each weight is stored multiplied by m_SmoothingFactor.
// NOTE(review): the control flow separating the empty-vector branch from the size
// check is on lines not visible in this listing.
349 template<
class TInputImage,
class TClassifiedImage>
354 if( betaMatrix.size() == 0 )
357 this->SetDefaultMRFNeighborhoodWeight();
361 m_NeighborhoodSize = 1;
362 for(
unsigned int i = 0; i < InputImageDimension; i++ )
364 m_NeighborhoodSize *= (2*m_InputImageNeighborhoodRadius[i]+1);
367 if( m_NeighborhoodSize != static_cast<int>(betaMatrix.size()) )
369 throw ExceptionObject(__FILE__, __LINE__,
"NeighborhoodSize != betaMatrix.size()", ITK_LOCATION);
374 m_MRFNeighborhoodWeight.resize( m_NeighborhoodSize );
376 for(
unsigned int i = 0; i < betaMatrix.size(); i++ )
378 m_MRFNeighborhoodWeight[i] = (betaMatrix[i] * m_SmoothingFactor);
// Prepares per-run state (partially visible in this listing):
//  - throws ExceptionObject if m_NumberOfClasses is not positive;
//  - accumulates the total input pixel count and the count of "valid" output pixels,
//    i.e. the product over dimensions of (size - 2*radius);
//  - creates and allocates m_LabelStatusImage over a region with the input size and
//    `index`, then walks it with rIter (loop body not visible here).
// NOTE(review): whether the two pixel-count members are reset to 1 before the loop is
// not visible here; if not, repeated calls would keep multiplying.
// NOTE(review): if the radius element type is unsigned, (tmp - 2*radius) wraps when an
// image dimension is smaller than 2*radius -- confirm the inputs are large enough.
388 template<
class TInputImage,
class TClassifiedImage>
393 if( m_NumberOfClasses <= 0 )
395 throw ExceptionObject(__FILE__, __LINE__,
"NumberOfClasses <= 0.",ITK_LOCATION);
404 for(
unsigned int i=0; i < InputImageDimension; i++ )
406 tmp =
static_cast<int>(inputImageSize[i]);
408 m_TotalNumberOfPixelsInInputImage *= tmp;
410 m_TotalNumberOfValidPixelsInOutputImage *=
411 ( tmp - 2*m_InputImageNeighborhoodRadius[i] );
// The label-status image covers the whole input extent.
419 region.SetSize( inputImageSize );
420 region.SetIndex( index );
422 m_LabelStatusImage = LabelStatusImageType::New();
423 m_LabelStatusImage->SetLargestPossibleRegion( region );
424 m_LabelStatusImage->SetBufferedRegion( region );
425 m_LabelStatusImage->Allocate();
428 m_LabelStatusImage->GetBufferedRegion() );
431 while( !rIter.IsAtEnd() )
// Iterative MRF relabelling driver.
// Each iteration calls MinimizeFunctional() and then counts entries of
// m_LabelStatusImage equal to 1 into m_ErrorCounter. Iteration stops once
// m_MaximumNumberOfIterations is reached, or once m_ErrorCounter drops to
// maxNumPixelError (m_ErrorTolerance times the number of valid output pixels,
// rounded). m_StopCondition records which criterion ended the loop.
443 template<
class TInputImage,
class TClassifiedImage>
450 this->GetInput()->GetBufferedRegion().GetSize();
452 int totalNumberOfPixelsInInputImage = 1;
454 for(
unsigned int i = 0; i < InputImageDimension; i++ )
456 totalNumberOfPixelsInInputImage *=
static_cast<int>(inputImageSize[ i ]);
459 int maxNumPixelError = Math::Round<int>( m_ErrorTolerance *
460 m_TotalNumberOfValidPixelsInOutputImage);
462 m_NumberOfIterations = 0;
465 itkDebugMacro(<<
"Iteration No." << m_NumberOfIterations);
467 MinimizeFunctional();
468 m_NumberOfIterations += 1;
// NOTE(review): the counter starts at (valid - total), i.e. negative by the number of
// pixels excluded from the valid count; presumably this offsets status entries of
// border pixels -- confirm against DoNeighborhoodOperation.
469 m_ErrorCounter = m_TotalNumberOfValidPixelsInOutputImage -
470 totalNumberOfPixelsInInputImage;
473 m_LabelStatusImage->GetBufferedRegion() );
476 while( !rIter.IsAtEnd() )
478 if ( rIter.Get( ) == 1 ) m_ErrorCounter += 1;
482 while(( m_NumberOfIterations < m_MaximumNumberOfIterations ) &&
483 ( m_ErrorCounter > maxNumPixelError ) );
// Record why the loop ended; the iteration limit takes precedence if both hold.
486 if( m_NumberOfIterations >= m_MaximumNumberOfIterations )
488 m_StopCondition = MaximumNumberOfIterations;
490 else if( m_ErrorCounter <= maxNumPixelError )
492 m_StopCondition = ErrorTolerance;
// Template header of a member definition whose body is not visible in this listing
// (original lines 502-514). NOTE(review): presumably MinimizeFunctional, which
// ApplyMRFImageFilter calls -- confirm.
502 template<
class TInputImage,
class TClassifiedImage>
// One labelling sweep (function name not visible in this listing). Sizes the
// per-class scratch vectors, builds boundary-face lists for the input, labelled and
// label-status images using their respective radii, and walks neighborhood
// iterators over the faces in lockstep, calling DoNeighborhoodOperation at each
// position.
// NOTE(review): only the first face of each list (begin()) is visible here; whether
// the remaining faces are also visited is on lines not shown.
515 template<
class TInputImage,
class TClassifiedImage>
524 m_NeighborInfluence.resize( m_NumberOfClasses );
529 m_MahalanobisDistance.resize( m_NumberOfClasses );
550 inputImageFacesCalculator( inputImage,
551 inputImage->GetBufferedRegion(),
552 m_InputImageNeighborhoodRadius );
555 labelledImageFaceList =
556 labelledImageFacesCalculator( labelledImage,
557 labelledImage->GetBufferedRegion(),
558 m_LabelledImageNeighborhoodRadius );
560 labelStatusImageFaceList =
561 labelStatusImageFacesCalculator( m_LabelStatusImage,
562 m_LabelStatusImage->GetBufferedRegion(),
563 m_LabelStatusImageNeighborhoodRadius );
566 = inputImageFaceList.begin();
569 = labelledImageFaceList.begin();
572 = labelStatusImageFaceList.begin();
576 nInputImageNeighborhoodIter( m_InputImageNeighborhoodRadius,
578 *inputImageFaceListIter );
581 nLabelledImageNeighborhoodIter( m_LabelledImageNeighborhoodRadius,
583 *labelledImageFaceListIter );
586 nLabelStatusImageNeighborhoodIter( m_LabelStatusImageNeighborhoodRadius,
588 *labelStatusImageFaceListIter );
// The input iterator drives the walk; the other two advance in step with it.
591 while( !nInputImageNeighborhoodIter.
IsAtEnd() )
595 this->DoNeighborhoodOperation( nInputImageNeighborhoodIter,
596 nLabelledImageNeighborhoodIter,
597 nLabelStatusImageNeighborhoodIter );
600 ++nInputImageNeighborhoodIter;
601 ++nLabelledImageNeighborhoodIter;
602 ++nLabelStatusImageNeighborhoodIter;
612 template<
class TInputImage,
class TClassifiedImage>
625 const std::vector<double> & pixelMembershipValue =
626 m_ClassifierPtr->GetPixelMembershipValue( *inputPixelVec );
631 for( index = 0; index < m_NeighborInfluence.size(); index++ )
633 m_NeighborInfluence[index]= 0;
639 for(
int i = 0; i < m_NeighborhoodSize; ++i )
642 labelledPixel = labelledIter.
GetPixel( i );
643 index = (
unsigned int) labelledPixel;
644 m_NeighborInfluence[ index ] += m_MRFNeighborhoodWeight[ i ];
649 for( index = 0; index < m_NumberOfClasses; index++ )
651 m_MahalanobisDistance[index] = m_NeighborInfluence[index] -
652 pixelMembershipValue[index];
656 double maximumDistance = -1e+20;
658 double tmpPixDistance;
659 for( index = 0; index < m_NumberOfClasses; index++ )
661 tmpPixDistance = m_MahalanobisDistance[index];
662 if ( tmpPixDistance > maximumDistance )
664 maximumDistance = tmpPixDistance;
675 if( pixLabel != (
int) ( *previousLabel ) )
678 for(
int i = 0; i < m_NeighborhoodSize; ++i )