17 #ifndef __itkMIRegistrationFunction_txx
18 #define __itkMIRegistrationFunction_txx
23 #include "vnl/vnl_math.h"
26 #include <vnl/vnl_matrix.h>
// Constructor (fragment; signature and braces are not visible in this
// extract). Sizes the neighborhood sample count from the radius, sets the
// numeric thresholds, clears the image inputs, and creates the gradient
// calculators and the default interpolator.
33 template <
class TFixedImage,
class TMovingImage,
class TDeformationField>
// m_NumberOfSamples is multiplied by (2*r[j]+1) for every dimension, i.e. the
// pixel count of a box neighborhood of radius r (r's definition and the
// initial value of m_NumberOfSamples are not visible here).
43 for( j = 0; j < ImageDimension; j++ )
46 m_NumberOfSamples *= (r[j]*2+1);
54 m_DenominatorThreshold = 1e-9;
55 m_IntensityDifferenceThreshold = 0.001;
// Both images start unset; they must be supplied before the update is computed.
56 this->SetMovingImage(
NULL);
57 this->SetFixedImage(
NULL);
// Separate gradient calculators for the fixed and moving images.
58 m_FixedImageGradientCalculator = GradientCalculatorType::New();
65 m_MovingImageGradientCalculator = GradientCalculatorType::New();
// Default moving-image interpolator (the assignment target is on a line not
// visible in this extract).
68 DefaultInterpolatorType::New();
// PrintSelf (fragment): forwards to the superclass printer. Any lines that
// print this class's own members are not visible in this extract.
78 template <
class TFixedImage,
class TMovingImage,
class TDeformationField>
84 Superclass::PrintSelf(os, indent);
// Per-iteration setup (presumably InitializeIteration; the signature is not
// visible). Validates the inputs, then binds the gradient calculators and the
// interpolator to the current images.
101 template <
class TFixedImage,
class TMovingImage,
class TDeformationField>
// Fail fast: every later sampling step dereferences these three objects.
106 if( !this->m_MovingImage || !this->m_FixedImage || !m_MovingImageInterpolator )
108 itkExceptionMacro( <<
"MovingImage, FixedImage and/or Interpolator not set" );
112 m_FixedImageGradientCalculator->SetInputImage( this->m_FixedImage );
117 m_MovingImageGradientCalculator->SetInputImage( this->m_MovingImage );
120 m_MovingImageInterpolator->SetInputImage( this->m_MovingImage );
// Per-pixel update (presumably ComputeUpdate; the signature is not visible).
// It estimates a local mutual-information gradient at oindex:
//  - sample set A: a strided neighborhood of fixed/moving samples plus a few
//    extra samples;
//  - sample set B: the pixels within index distance 1 of oindex;
//  - Gaussian Parzen kernels over B-A intensity differences;
//  - the result is accumulated from fixed-image gradient differences.
// Returns the scaled derivative, or `update` on early exit (how `update` is
// initialized is not visible here).
129 template <
class TFixedImage,
class TMovingImage,
class TDeformationField>
// NOTE(review): matrixType is not used in any visible line.
151 typedef vnl_matrix<double> matrixType;
152 typedef std::vector<double> sampleContainerType;
153 typedef std::vector<CovariantVectorType> gradContainerType;
154 typedef std::vector<double> gradMagContainerType;
155 typedef std::vector<unsigned int> inImageIndexContainerType;
// Start the accumulated derivative at zero in every dimension.
166 for (indct=0;indct<ImageDimension;indct++)
169 derivative[indct]=0.0;
// Skip the pixel when BOTH images are at or below 1/255 at oindex.
// NOTE(review): the hard-coded 1/255 presumably assumes intensities rescaled
// to [0,1]; for raw integer data only values <= 0 are skipped. Confirm the
// expected intensity range.
173 float thresh2=1.0/255.;
174 float thresh1=1.0/255.;
175 if ( this->m_MovingImage->GetPixel(oindex) <= thresh1 &&
176 this->m_FixedImage->GetPixel(oindex) <= thresh2 )
return update;
// Sample set A, neighborhood part: visit every samplestep-th pixel of the
// radius-hradius neighborhood around the current position. For each in-image
// pixel, record the fixed value, the fixed gradient and the moving value
// interpolated at the deformed physical point.
178 typename FixedImageType::SizeType hradius=this->GetRadius();
181 typename FixedImageType::SizeType imagesize=img->GetLargestPossibleRegion().GetSize();
187 sampleContainerType fixedSamplesA;
188 sampleContainerType movingSamplesA;
189 sampleContainerType fixedSamplesB;
190 sampleContainerType movingSamplesB;
// NOTE(review): the index and gradient-magnitude containers below are never
// filled or read in any visible line.
191 inImageIndexContainerType inImageIndicesA;
192 gradContainerType fixedGradientsA;
193 gradMagContainerType fixedGradMagsA;
194 inImageIndexContainerType inImageIndicesB;
195 gradContainerType fixedGradientsB;
196 gradMagContainerType fixedGradMagsB;
198 unsigned int samplestep=2;
// Running extrema and sums; min starts high and max starts at 0.
200 double minf=1.e9,minm=1.e9,maxf=0.0,maxm=0.0;
201 double movingMean=0.0;
202 double fixedMean=0.0;
203 double fixedValue=0,movingValue=0;
// NOTE(review): sampct is later used as the mean divisor; its increment is
// not visible in this extract.
205 unsigned int sampct=0;
210 this->GetDeformationField(),
211 this->GetDeformationField()->GetRequestedRegion());
213 unsigned int hoodlen=asamIt.
Size();
216 for(indct=0; indct<hoodlen; indct=indct+samplestep)
// A pixel counts as in-image only if every coordinate lies in [0, size-1].
220 for (
unsigned int dd=0; dd<ImageDimension; dd++)
222 if ( index[dd] < 0 || index[dd] > static_cast<typename IndexType::IndexValueType>(imagesize[dd]-1) ) inimage=
false;
232 fixedValue = (double) this->m_FixedImage->GetPixel( index );
235 fixedGradient = m_FixedImageGradientCalculator->EvaluateAtIndex( index );
// Map the fixed-image index to physical space and displace it by the
// deformation vector to find where to sample the moving image.
238 typedef typename DeformationFieldType::PixelType DeformationPixelType;
239 const DeformationPixelType itvec = this->GetDeformationField()->GetPixel(index);
242 this->GetFixedImage()->TransformIndexToPhysicalPoint(index, mappedPoint);
243 for( j = 0; j < ImageDimension; j++ )
245 mappedPoint[j] += itvec[j];
// What movingValue becomes when the point is outside the buffer is not
// visible in this extract.
247 if( m_MovingImageInterpolator->IsInsideBuffer( mappedPoint ) )
249 movingValue = m_MovingImageInterpolator->Evaluate( mappedPoint );
// NOTE(review): because of the `else if`, a sample that raises the max is
// never compared against the min. With maxf=0/minf=1e9 initially, the first
// positive sample updates only maxf, so minf/minm can stay at 1e9 (e.g. for
// monotonically increasing samples). Two independent `if`s would track true
// extrema.
256 if (fixedValue > maxf) maxf=fixedValue;
257 else if (fixedValue < minf) minf=fixedValue;
258 if (movingValue > maxm) maxm=movingValue;
259 else if (movingValue < minm) minm=movingValue;
261 fixedMean += fixedValue;
262 movingMean += movingValue;
// All three containers prepend, so index k refers to the same sample in each.
// NOTE(review): vector::insert at begin() is O(n) per call (O(n^2) overall);
// push_back on all three would keep them aligned at O(1) amortized.
264 fixedSamplesA.insert(fixedSamplesA.begin(),(double)fixedValue);
265 fixedGradientsA.insert(fixedGradientsA.begin(),fixedGradient);
266 movingSamplesA.insert(movingSamplesA.begin(),(double)movingValue);
// Sample set A, extra part: iterate `randasamit` (presumably a random-sample
// iterator over `region`; its declaration is not visible) for at most
// numberOfSamples accepted samples, appending them to set A.
280 bool getrandasamples=
true;
284 typename FixedImageType::RegionType region=img->GetLargestPossibleRegion();
288 unsigned int numberOfSamples=10;
296 while( !randasamit.
IsAtEnd() && indct < numberOfSamples )
// In-image test as in the neighborhood pass. d accumulates the squared index
// distance to oindex; its use is not visible here.
302 for (
unsigned int dd=0; dd<ImageDimension; dd++)
304 if ( index[dd] < 0 || index[dd] > static_cast<typename IndexType::IndexValueType>(imagesize[dd]-1) ) inimage=
false;
305 d += (index[dd]-oindex[dd])*(index[dd]-oindex[dd]);
315 fixedValue = (double) this->m_FixedImage->GetPixel( index );
316 fixedGradient = m_FixedImageGradientCalculator->EvaluateAtIndex( index );
// fgm = squared magnitude of the fixed-image gradient (its reset to 0 is not
// visible here).
318 for( j = 0; j < ImageDimension; j++ )
320 fgm += fixedGradient[j] *fixedGradient[j];
323 typedef typename DeformationFieldType::PixelType DeformationPixelType;
324 const DeformationPixelType itvec=this->GetDeformationField()->GetPixel(index);
326 this->GetFixedImage()->TransformIndexToPhysicalPoint(index, mappedPoint);
327 for( j = 0; j < ImageDimension; j++ )
329 mappedPoint[j] += itvec[j];
331 if( m_MovingImageInterpolator->IsInsideBuffer( mappedPoint ) )
333 movingValue = m_MovingImageInterpolator->Evaluate( mappedPoint );
// Accept only samples carrying some signal: a positive intensity in either
// image, or a nonzero fixed gradient.
// NOTE(review): unlike the neighborhood pass, no visible line updates
// minf/maxf/minm/maxm for these samples.
342 if ( fixedValue > 0 || movingValue > 0 || fgm > 0 )
344 fixedMean += fixedValue;
345 movingMean += movingValue;
347 fixedSamplesA.insert(fixedSamplesA.begin(),(double)fixedValue);
348 fixedGradientsA.insert(fixedGradientsA.begin(),fixedGradient);
349 movingSamplesA.insert(movingSamplesA.begin(),(double)movingValue);
// Sample set B: walk the hradius neighborhood of the field and keep only the
// in-image pixels whose Euclidean index distance to oindex is <= 1 (the
// center plus its face neighbors). Record the same fixed value, gradient and
// deformed moving value as for set A.
363 for (j=0;j<ImageDimension; j++)
368 hoodIt( hradius, field, field->GetRequestedRegion());
372 for(indct=0; indct<hoodIt.Size(); indct=indct+1)
374 const IndexType index=hoodIt.GetIndex(indct);
377 for (
unsigned int dd=0; dd<ImageDimension; dd++)
379 if ( index[dd] < 0 || index[dd] > static_cast<typename IndexType::IndexValueType>(imagesize[dd]-1) ) inimage=
false;
380 d += (index[dd]-oindex[dd])*(index[dd]-oindex[dd]);
382 if (inimage && vcl_sqrt(d) <= 1.0)
389 fixedValue = (double) this->m_FixedImage->GetPixel( index );
390 fixedGradient = m_FixedImageGradientCalculator->EvaluateAtIndex( index );
// NOTE(review): this reads this->m_DeformationField directly, while the other
// sampling passes use this->GetDeformationField(); consider one accessor.
394 const typename DeformationFieldType::PixelType hooditvec=this->m_DeformationField->GetPixel(index);
396 this->GetFixedImage()->TransformIndexToPhysicalPoint(index, mappedPoint);
397 for(j = 0; j < ImageDimension; j++ )
399 mappedPoint[j] += hooditvec[j];
401 if( m_MovingImageInterpolator->IsInsideBuffer( mappedPoint ) )
403 movingValue = m_MovingImageInterpolator->Evaluate( mappedPoint );
410 fixedSamplesB.insert(fixedSamplesB.begin(),(double)fixedValue);
411 fixedGradientsB.insert(fixedGradientsB.begin(),fixedGradient);
412 movingSamplesB.insert(movingSamplesB.begin(),(double)movingValue);
// Sample statistics: compute the means and standard deviations of set A,
// shift the samples by the tracked minima, and derive the Parzen kernel
// widths (sigma * sigmaw; sigmaw is defined on a line not visible here).
418 double jointsigma=0.0;
420 const double numsamplesB= (double) fixedSamplesB.size();
421 const double numsamplesA= (double) fixedSamplesA.size();
422 double nsamp=numsamplesB;
// NOTE(review): the means divide by sampct while the variances below divide by
// numsamplesA. If sampct is not incremented once per A sample, or is 0, the
// statistics are inconsistent or divide by zero. Confirm against the hidden
// lines.
428 fixedMean /= (double)sampct;
429 movingMean /= (double)sampct;
433 for(indct=0; indct<(
unsigned int)numsamplesA; indct++)
436 fixedValue=fixedSamplesA[indct];
437 movingValue=movingSamplesA[indct];
439 fsigma += (fixedValue-fixedMean)*(fixedValue-fixedMean);
440 msigma += (movingValue-movingMean)*(movingValue-movingMean);
// NOTE(review): this adds the RUNNING totals fsigma+msigma on every
// iteration, not this sample's two squared deviations, so jointsigma grows
// roughly quadratically with the sample count.
441 jointsigma += fsigma+msigma;
// Shift samples by the minima; the kernels below use only B-A differences,
// so an equal shift of both sets cancels.
445 fixedSamplesA[indct]=fixedSamplesA[indct]-minf;
446 movingSamplesA[indct]=movingSamplesA[indct]-minm;
// NOTE(review): B samples are shifted only inside this A loop. If set B ever
// has more samples than set A, B samples with index >= numsamplesA stay
// unshifted and their differences are off by minf/minm. Also, this compares
// an unsigned index with a double.
447 if (indct < numsamplesB)
449 fixedSamplesB[indct]=fixedSamplesB[indct]-minf;
450 movingSamplesB[indct]=movingSamplesB[indct]-minm;
// Population standard deviations over set A.
// NOTE(review): with numsamplesA == 0 these become NaN, and the `< 1.e-7`
// guard below is false for NaN, so no early return happens.
// NOTE(review): the m_-prefixed names below are LOCALS, not members; the
// prefix is misleading.
456 fsigma=vcl_sqrt(fsigma/numsamplesA);
458 double m_FixedImageStandardDeviation=fsigma*sigmaw;
459 msigma=vcl_sqrt(msigma/numsamplesA);
460 double m_MovingImageStandardDeviation=msigma*sigmaw;
461 jointsigma=vcl_sqrt(jointsigma/numsamplesA);
// A (near-)constant neighborhood gives a degenerate kernel width; skip it.
463 if (fsigma < 1.e-7 || msigma < 1.e-7 )
return update;
// Parzen-window estimation. For each B sample, sum unnormalized Gaussian
// kernels exp(-0.5*(Δ/σ)^2) over all A samples for the fixed, moving and
// joint (product) densities, and accumulate -log of each sum. A second pass
// forms per-pair weights (fixed weight minus joint weight) times the fixed
// intensity difference. These weight the difference of fixed-image gradients
// added to `derivative`.
// m_MinProbability (a local despite its m_ prefix) seeds every denominator,
// so the log() and divisions below never see zero.
466 double m_MinProbability = 0.0001;
467 double dLogSumFixed=0.,dLogSumMoving=0.,dLogSumJoint=0.0;
468 unsigned int bsamples;
469 unsigned int asamples;
472 for(bsamples=0; bsamples<(
unsigned int)numsamplesB; bsamples++)
474 double dDenominatorMoving = m_MinProbability;
475 double dDenominatorJoint = m_MinProbability;
476 double dDenominatorFixed = m_MinProbability;
// NOTE(review): dSumFixed is seeded and accumulated exactly like
// dDenominatorFixed in the visible lines, so it looks redundant.
477 double dSumFixed = m_MinProbability;
480 for(asamples=0; asamples<(
unsigned int)numsamplesA; asamples++)
482 double valueFixed = ( fixedSamplesB[bsamples] - fixedSamplesA[asamples] )
483 / m_FixedImageStandardDeviation;
484 valueFixed = vcl_exp(-0.5*valueFixed*valueFixed);
486 double valueMoving = ( movingSamplesB[bsamples] - movingSamplesA[asamples] )
487 / m_MovingImageStandardDeviation;
488 valueMoving = vcl_exp(-0.5*valueMoving*valueMoving);
490 dDenominatorMoving += valueMoving;
491 dDenominatorFixed += valueFixed;
492 dSumFixed += valueFixed;
// Joint kernel is the product of the two marginal kernels.
496 dDenominatorJoint += valueMoving * valueFixed;
499 dLogSumFixed -= vcl_log(dSumFixed );
500 dLogSumMoving -= vcl_log(dDenominatorMoving );
501 dLogSumJoint -= vcl_log(dDenominatorJoint );
// Second pass over A: same kernels, now normalized by the sums above.
// NOTE(review): the kernel values are recomputed here; caching them from the
// first pass would halve the exp() calls.
504 for(asamples=0; asamples<(
unsigned int)numsamplesA; asamples++)
506 double valueFixed = ( fixedSamplesB[bsamples] - fixedSamplesA[asamples] )
507 / m_FixedImageStandardDeviation;
508 valueFixed = vcl_exp(-0.5*valueFixed*valueFixed);
510 double valueMoving = ( movingSamplesB[bsamples] - movingSamplesA[asamples] )
511 / m_MovingImageStandardDeviation;
512 valueMoving = vcl_exp(-0.5*valueMoving*valueMoving);
513 const double weightFixed = valueFixed / dDenominatorFixed;
515 const double weightJoint = valueMoving * valueFixed / dDenominatorJoint;
518 double weight = ( weightFixed - weightJoint );
519 weight *= ( fixedSamplesB[bsamples] - fixedSamplesA[asamples] );
523 for (
unsigned int i=0; i<ImageDimension;i++)
525 derivative[i] += ( fixedGradientsB[bsamples][i] - fixedGradientsA[asamples][i] ) * weight;
// Metric value and final update vector.
// threshold = -0.1 * nsamp * log(1e-4), about 0.92 * nsamp; the action taken
// when a log-sum exceeds it is on lines not visible in this extract.
530 const double threshold = -0.1 * nsamp * vcl_log(m_MinProbability );
531 if( dLogSumMoving > threshold || dLogSumFixed > threshold ||
532 dLogSumJoint > threshold )
// value = -Σlog p_f - Σlog p_m + Σlog p_joint + log(nsamp); it is added to
// both m_MetricTotal and this->m_Energy.
540 value = dLogSumFixed + dLogSumMoving - dLogSumJoint;
542 value += vcl_log(nsamp );
544 m_MetricTotal += value;
545 this->m_Energy += value;
// Scale by 1/σ_f^2 (σ_f is the local fixed-image kernel width).
548 derivative /= vnl_math_sqr( m_FixedImageStandardDeviation );
// Euclidean norm of the derivative.
550 double updatenorm=0.0;
551 for (
unsigned int tt=0; tt<ImageDimension; tt++)
553 updatenorm += derivative[tt]*derivative[tt];
555 updatenorm=vcl_sqrt(updatenorm);
// Optionally normalize to unit length; the 1e-20 guard avoids dividing by a
// (near-)zero norm.
557 if (updatenorm > 1.e-20 && this->GetNormalizeGradient())
559 derivative=derivative/updatenorm;
562 return derivative*this->GetGradientStep();