Image Registration Method Exhaustive
Overview
This script demonstrates the use of the Exhaustive optimizer in the ImageRegistrationMethod to estimate a good initial rotation position.
Because gradient-descent-based optimization can get stuck in local minima, a good initial transform is critical for reasonable results. Searching a reasonable space on a grid by brute force can be a reliable way to obtain a starting location for further optimization.
The initial translation and center of rotation of the transform are initialized from the first-order moments of the image intensities. Then, in either 2D or 3D, an Euler transform is used to exhaustively search a grid of the rotation space at a fixed step size. The resulting transform is a reasonable guess of where to start further registration.
Code
using System;
using itk.simple;
namespace itk.simple.examples
{
    /// <summary>
    /// Command observer for <c>sitkIterationEvent</c>. On iteration 0 it prints the
    /// optimizer scales once; on every call it prints
    /// "iteration = metric : [position...]" with five fractional digits.
    /// </summary>
    class IterationUpdate : Command
    {
        private ImageRegistrationMethod m_Method;

        /// <param name="m">Registration method whose optimizer state is reported.</param>
        public IterationUpdate(ImageRegistrationMethod m)
        {
            m_Method = m;
        }

        public override void Execute()
        {
            if (m_Method.GetOptimizerIteration() == 0)
            {
                VectorDouble scales = m_Method.GetOptimizerScales();
                Console.Write("Scales: [" + scales[0]);
                for (int i = 1; i < scales.Count; i++)
                {
                    Console.Write(", " + scales[i]);
                }
                Console.WriteLine("]");
            }

            VectorDouble pos = m_Method.GetOptimizerPosition();
            Console.Write("{0,3} = {1,7:F5} : [{2:F5}",
                m_Method.GetOptimizerIteration(),
                m_Method.GetMetricValue(),
                pos[0]);
            for (int i = 1; i < pos.Count; i++)
            {
                Console.Write(", {0:F5}", pos[i]);
            }
            Console.WriteLine("]");
        }
    }

    /// <summary>
    /// Exhaustively searches the rotation space of an Euler2D/Euler3D transform
    /// (initialized by CenteredTransformInitializer) and writes the best transform.
    /// Arguments: fixedImageFile movingImageFile outputTransformFile.
    /// Unless SITK_NOSHOW is set, it also shows a fixed/resampled-moving overlay.
    /// </summary>
    class ImageRegistrationMethodExhaustive
    {
        static void Main(string[] args)
        {
            if (args.Length < 3)
            {
                Console.WriteLine("Usage: ImageRegistrationMethodExhaustive <fixedImageFile> <movingImageFile> <outputTransformFile>");
                // Signal failure to the caller, as the C++ and Java versions do.
                Environment.ExitCode = 1;
                return;
            }

            Image fixedImage = SimpleITK.ReadImage(args[0], PixelIDValueEnum.sitkFloat32);
            Image movingImage = SimpleITK.ReadImage(args[1], PixelIDValueEnum.sitkFloat32);

            // Only 2D and 3D Euler searches are configured below. Any other dimension
            // would leave tx null and fail inside CenteredTransformInitializer.
            var dimension = fixedImage.GetDimension();
            if (dimension != 2 && dimension != 3)
            {
                Console.Error.WriteLine("Unsupported image dimension: " + dimension + " (only 2D and 3D images are supported)");
                Environment.ExitCode = 1;
                return;
            }

            ImageRegistrationMethod R = new ImageRegistrationMethod();
            R.SetMetricAsMattesMutualInformation(50); // numberOfHistogramBins

            int samplePerAxis = 12;
            Transform tx;
            if (dimension == 2)
            {
                tx = new Euler2DTransform();
                // Set the number of samples (radius) in each dimension, with a default step size of 1.0
                VectorUInt32 exhaustiveSteps = new VectorUInt32(new uint[] { (uint)(samplePerAxis / 2), 0, 0 });
                R.SetOptimizerAsExhaustive(exhaustiveSteps);
                // Utilize the scale to set the step size for each dimension
                VectorDouble scales = new VectorDouble(new double[] { 2.0 * Math.PI / samplePerAxis, 1.0, 1.0 });
                R.SetOptimizerScales(scales);
            }
            else
            {
                tx = new Euler3DTransform();
                VectorUInt32 exhaustiveSteps = new VectorUInt32(new uint[] {
                    (uint)(samplePerAxis / 2),
                    (uint)(samplePerAxis / 2),
                    (uint)(samplePerAxis / 4),
                    0, 0, 0
                });
                R.SetOptimizerAsExhaustive(exhaustiveSteps);
                VectorDouble scales = new VectorDouble(new double[] {
                    2.0 * Math.PI / samplePerAxis,
                    2.0 * Math.PI / samplePerAxis,
                    2.0 * Math.PI / samplePerAxis,
                    1.0, 1.0, 1.0
                });
                R.SetOptimizerScales(scales);
            }

            // Initialize the transform with a translation and the center of rotation from the moments of intensity.
            tx = SimpleITK.CenteredTransformInitializer(fixedImage, movingImage, tx);
            R.SetInitialTransform(tx);
            R.SetInterpolator(InterpolatorEnum.sitkLinear);

            IterationUpdate cmd = new IterationUpdate(R);
            R.AddCommand(EventEnum.sitkIterationEvent, cmd);

            Transform outTx = R.Execute(fixedImage, movingImage);

            Console.WriteLine("-------");
            Console.WriteLine(outTx.ToString());
            Console.WriteLine("Optimizer stop condition: " + R.GetOptimizerStopConditionDescription());
            Console.WriteLine(" Iteration: " + R.GetOptimizerIteration());
            Console.WriteLine(" Metric value: " + R.GetMetricValue());

            SimpleITK.WriteTransform(outTx, args[2]);

            if (Environment.GetEnvironmentVariable("SITK_NOSHOW") == null)
            {
                ResampleImageFilter resampler = new ResampleImageFilter();
                resampler.SetReferenceImage(fixedImage);
                resampler.SetInterpolator(InterpolatorEnum.sitkLinear);
                resampler.SetDefaultPixelValue(1);
                resampler.SetTransform(outTx);

                Image output = resampler.Execute(movingImage);

                Image simg1 = SimpleITK.Cast(SimpleITK.RescaleIntensity(fixedImage), PixelIDValueEnum.sitkUInt8);
                Image simg2 = SimpleITK.Cast(SimpleITK.RescaleIntensity(output), PixelIDValueEnum.sitkUInt8);
                // Halve each channel before summing. Adding two sitkUInt8 images first
                // would cast the sum back to 8 bits and wrap values above 255.
                Image average = SimpleITK.Add(SimpleITK.Divide(simg1, 2), SimpleITK.Divide(simg2, 2));
                Image cimg = SimpleITK.Compose(simg1, simg2, average);
                SimpleITK.Show(cimg, "ImageRegistrationExhaustive Composition");
            }
        }
    }
}
#include <SimpleITK.h>
#include <iostream>
#include <stdlib.h>
#include <iomanip>
#include <cmath>
#ifndef M_PI
# define M_PI 3.14159265358979323846
#endif
namespace sitk = itk::simple;
// Observer for sitkIterationEvent. On iteration 0 it prints the optimizer
// scales once; every call then prints "iteration = metric : position" with
// five fixed decimals. std::cout's formatting is restored afterwards.
class IterationUpdate : public sitk::Command
{
public:
  IterationUpdate(const sitk::ImageRegistrationMethod & m)
    : m_Method(m)
  {}

  void
  Execute() override
  {
    // SimpleITK provides operator<< for std::vector (scales / position).
    using sitk::operator<<;

    const auto iteration = m_Method.GetOptimizerIteration();
    if (iteration == 0)
    {
      std::cout << "Scales: " << m_Method.GetOptimizerScales() << std::endl;
    }

    // Snapshot the current formatting so the fixed/precision settings do not leak.
    std::ios savedFormat(nullptr);
    savedFormat.copyfmt(std::cout);

    std::cout << std::fixed << std::setfill(' ') << std::setprecision(5);
    std::cout << std::setw(3) << iteration;
    std::cout << " = " << std::setw(7) << m_Method.GetMetricValue();
    std::cout << " : " << m_Method.GetOptimizerPosition() << std::endl;

    std::cout.copyfmt(savedFormat);
  }

private:
  const sitk::ImageRegistrationMethod & m_Method;
};
// Exhaustively searches the rotation space of an Euler2D/Euler3D transform
// (initialized by CenteredTransformInitializer) and writes the best transform.
// Usage: <fixedImageFile> <movingImageFile> <outputTransformFile>
// Returns 1 on bad arguments or an unsupported image dimension, 0 on success.
int
main(int argc, char * argv[])
{
  if (argc < 4)
  {
    std::cout << "Usage: " << argv[0] << " <fixedImageFile> <movingImageFile> <outputTransformFile>" << std::endl;
    return 1;
  }

  sitk::Image fixedImage = sitk::ReadImage(argv[1], sitk::sitkFloat32);
  sitk::Image movingImage = sitk::ReadImage(argv[2], sitk::sitkFloat32);

  // Only 2D and 3D Euler searches are configured below. Any other dimension
  // would continue with a default transform and an unconfigured optimizer.
  const unsigned int dimension = fixedImage.GetDimension();
  if (dimension != 2 && dimension != 3)
  {
    std::cerr << "Unsupported image dimension: " << dimension << " (only 2D and 3D images are supported)" << std::endl;
    return 1;
  }

  sitk::ImageRegistrationMethod R;
  R.SetMetricAsMattesMutualInformation(50); // numberOfHistogramBins

  int           samplePerAxis = 12;
  sitk::Transform tx;
  if (dimension == 2)
  {
    tx = sitk::Euler2DTransform();
    // Set the number of samples (radius) in each dimension, with a default step size of 1.0
    std::vector<unsigned int> exhaustiveSteps = { static_cast<unsigned int>(samplePerAxis / 2), 0, 0 };
    R.SetOptimizerAsExhaustive(exhaustiveSteps);
    // Utilize the scale to set the step size for each dimension
    std::vector<double> scales = { 2.0 * M_PI / samplePerAxis, 1.0, 1.0 };
    R.SetOptimizerScales(scales);
  }
  else
  {
    tx = sitk::Euler3DTransform();
    std::vector<unsigned int> exhaustiveSteps = { static_cast<unsigned int>(samplePerAxis / 2),
                                                  static_cast<unsigned int>(samplePerAxis / 2),
                                                  static_cast<unsigned int>(samplePerAxis / 4),
                                                  0,
                                                  0,
                                                  0 };
    R.SetOptimizerAsExhaustive(exhaustiveSteps);
    std::vector<double> scales = {
      2.0 * M_PI / samplePerAxis, 2.0 * M_PI / samplePerAxis, 2.0 * M_PI / samplePerAxis, 1.0, 1.0, 1.0
    };
    R.SetOptimizerScales(scales);
  }

  // Initialize the transform with a translation and the center of rotation from the moments of intensity.
  tx = sitk::CenteredTransformInitializer(fixedImage, movingImage, tx);
  R.SetInitialTransform(tx);
  R.SetInterpolator(sitk::sitkLinear);

  IterationUpdate cmd(R);
  R.AddCommand(sitk::sitkIterationEvent, cmd);

  sitk::Transform outTx = R.Execute(fixedImage, movingImage);

  std::cout << "-------" << std::endl;
  std::cout << outTx.ToString() << std::endl;
  std::cout << "Optimizer stop condition: " << R.GetOptimizerStopConditionDescription() << std::endl;
  std::cout << " Iteration: " << R.GetOptimizerIteration() << std::endl;
  std::cout << " Metric value: " << R.GetMetricValue() << std::endl;

  sitk::WriteTransform(outTx, argv[3]);

  if (getenv("SITK_NOSHOW") == nullptr)
  {
    sitk::ResampleImageFilter resampler;
    resampler.SetReferenceImage(fixedImage);
    resampler.SetInterpolator(sitk::sitkLinear);
    resampler.SetDefaultPixelValue(1);
    resampler.SetTransform(outTx);

    sitk::Image out = resampler.Execute(movingImage);

    sitk::Image simg1 = sitk::Cast(sitk::RescaleIntensity(fixedImage), sitk::sitkUInt8);
    sitk::Image simg2 = sitk::Cast(sitk::RescaleIntensity(out), sitk::sitkUInt8);
    // Halve each channel before summing. Adding two sitkUInt8 images first
    // would cast the sum back to 8 bits and wrap values above 255.
    sitk::Image average = sitk::Add(sitk::Divide(simg1, 2), sitk::Divide(simg2, 2));
    sitk::Image cimg = sitk::Compose(simg1, simg2, average);
    sitk::Show(cimg, "ImageRegistrationExhaustive Composition");
  }

  return 0;
}
import org.itk.simple.*;
import java.text.DecimalFormat;
/**
 * Observer for sitkIterationEvent.
 *
 * On iteration 0 it prints the optimizer scales once. Every call then prints
 * "iteration = metric : [position...]", with numbers formatted as "0.00000".
 */
class IterationUpdate extends Command {
  private ImageRegistrationMethod method;

  public IterationUpdate(ImageRegistrationMethod m) {
    method = m;
  }

  public void execute() {
    if (method.getOptimizerIteration() == 0) {
      VectorDouble scales = method.getOptimizerScales();
      StringBuilder header = new StringBuilder("Scales: [");
      header.append(scales.get(0));
      for (int k = 1; k < scales.size(); k++) {
        header.append(", ").append(scales.get(k));
      }
      System.out.println(header.append(']'));
    }

    VectorDouble position = method.getOptimizerPosition();
    DecimalFormat fixed5 = new DecimalFormat("0.00000");
    StringBuilder line = new StringBuilder(
        String.format("%3d = %7s : [",
            method.getOptimizerIteration(),
            fixed5.format(method.getMetricValue())));
    line.append(fixed5.format(position.get(0)));
    for (int k = 1; k < position.size(); k++) {
      line.append(", ").append(fixed5.format(position.get(k)));
    }
    System.out.println(line.append(']'));
  }
}
public class ImageRegistrationMethodExhaustive {
public static void main(String[] args) throws Exception {
if (args.length < 3) {
System.out.println("Usage: ImageRegistrationMethodExhaustive <fixedImageFile> <movingImageFile> <outputTransformFile>");
System.exit(1);
}
Image fixedImage = SimpleITK.readImage(args[0], PixelIDValueEnum.sitkFloat32);
Image movingImage = SimpleITK.readImage(args[1], PixelIDValueEnum.sitkFloat32);
ImageRegistrationMethod R = new ImageRegistrationMethod();
R.setMetricAsMattesMutualInformation(50); // numberOfHistogramBins
long samplePerAxis = 12;
Transform tx = null;
if (fixedImage.getDimension() == 2) {
tx = new Euler2DTransform();
// Set the number of samples (radius) in each dimension, with a default step size of 1.0
VectorUInt32 exhaustiveSteps = new VectorUInt32();
exhaustiveSteps.add(samplePerAxis / 2);
exhaustiveSteps.add(0L);
exhaustiveSteps.add(0L);
R.setOptimizerAsExhaustive(exhaustiveSteps);
// Utilize the scale to set the step size for each dimension
VectorDouble scales = new VectorDouble();
scales.add(2.0 * Math.PI / samplePerAxis);
scales.add(1.0);
scales.add(1.0);
R.setOptimizerScales(scales);
} else if (fixedImage.getDimension() == 3) {
tx = new Euler3DTransform();
VectorUInt32 exhaustiveSteps = new VectorUInt32();
exhaustiveSteps.add(samplePerAxis / 2);
exhaustiveSteps.add(samplePerAxis / 2);
exhaustiveSteps.add(samplePerAxis / 4);
exhaustiveSteps.add(0L);
exhaustiveSteps.add(0L);
exhaustiveSteps.add(0L);
R.setOptimizerAsExhaustive(exhaustiveSteps);
VectorDouble scales = new VectorDouble();
scales.add(2.0 * Math.PI / samplePerAxis);
scales.add(2.0 * Math.PI / samplePerAxis);
scales.add(2.0 * Math.PI / samplePerAxis);
scales.add(1.0);
scales.add(1.0);
scales.add(1.0);
R.setOptimizerScales(scales);
}
// Initialize the transform with a translation and the center of rotation from the moments of intensity.
tx = SimpleITK.centeredTransformInitializer(fixedImage, movingImage, tx);
R.setInitialTransform(tx);
R.setInterpolator(InterpolatorEnum.sitkLinear);
IterationUpdate cmd = new IterationUpdate(R);
R.addCommand(EventEnum.sitkIterationEvent, cmd);
Transform outTx = R.execute(fixedImage, movingImage);
System.out.println("-------");
System.out.println(outTx.toString());
System.out.println("Optimizer stop condition: " + R.getOptimizerStopConditionDescription());
System.out.println(" Iteration: " + R.getOptimizerIteration());
System.out.println(" Metric value: " + R.getMetricValue());
SimpleITK.writeTransform(outTx, args[2]);
if (System.getenv("SITK_NOSHOW") == null) {
ResampleImageFilter resampler = new ResampleImageFilter();
resampler.setReferenceImage(fixedImage);
resampler.setInterpolator(InterpolatorEnum.sitkLinear);
resampler.setDefaultPixelValue(1);
resampler.setTransform(outTx);
Image output = resampler.execute(movingImage);
Image simg1 = SimpleITK.cast(SimpleITK.rescaleIntensity(fixedImage), PixelIDValueEnum.sitkUInt8);
Image simg2 = SimpleITK.cast(SimpleITK.rescaleIntensity(output), PixelIDValueEnum.sitkUInt8);
Image cimg = SimpleITK.compose(simg1, simg2, SimpleITK.divide(SimpleITK.add(simg1, simg2), 2.0));
SimpleITK.show(cimg, "ImageRegistrationExhaustive Composition");
}
}
}
#!/usr/bin/env python
import sys
import os
from math import pi
import SimpleITK as sitk
def command_iteration(method):
    """Report optimizer progress; meant to be registered for sitkIterationEvent.

    On iteration 0 the optimizer scales are printed once. Every call then
    prints one line of the form ``<iteration> = <metric> : <position>``.
    """
    iteration = method.GetOptimizerIteration()
    if iteration == 0:
        print("Scales: ", method.GetOptimizerScales())
    metric = method.GetMetricValue()
    position = method.GetOptimizerPosition()
    print(f"{iteration:3} = {metric:7.5f} : {position}")
if len(sys.argv) < 4:
    print(
        "Usage:",
        sys.argv[0],
        "<fixedImageFile> <movingImageFile>",
        "<outputTransformFile>",
    )
    sys.exit(1)

fixed = sitk.ReadImage(sys.argv[1], sitk.sitkFloat32)
moving = sitk.ReadImage(sys.argv[2], sitk.sitkFloat32)

# Only 2D and 3D Euler rotation searches are configured below. Reject anything
# else up front instead of handing ``None`` to CenteredTransformInitializer.
dimension = fixed.GetDimension()
if dimension not in (2, 3):
    print(
        f"Unsupported image dimension: {dimension} "
        "(only 2D and 3D images are supported)",
        file=sys.stderr,
    )
    sys.exit(1)

R = sitk.ImageRegistrationMethod()

R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)

sample_per_axis = 12
if dimension == 2:
    tx = sitk.Euler2DTransform()
    # Set the number of samples (radius) in each dimension, with a
    # default step size of 1.0
    R.SetOptimizerAsExhaustive([sample_per_axis // 2, 0, 0])
    # Utilize the scale to set the step size for each dimension
    R.SetOptimizerScales([2.0 * pi / sample_per_axis, 1.0, 1.0])
else:
    tx = sitk.Euler3DTransform()
    R.SetOptimizerAsExhaustive(
        [
            sample_per_axis // 2,
            sample_per_axis // 2,
            sample_per_axis // 4,
            0,
            0,
            0,
        ]
    )
    R.SetOptimizerScales(
        [
            2.0 * pi / sample_per_axis,
            2.0 * pi / sample_per_axis,
            2.0 * pi / sample_per_axis,
            1.0,
            1.0,
            1.0,
        ]
    )

# Initialize the transform with a translation and the center of
# rotation from the moments of intensity.
tx = sitk.CenteredTransformInitializer(fixed, moving, tx)

R.SetInitialTransform(tx)

R.SetInterpolator(sitk.sitkLinear)

R.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(R))

outTx = R.Execute(fixed, moving)

print("-------")
print(outTx)
print(f"Optimizer stop condition: {R.GetOptimizerStopConditionDescription()}")
print(f" Iteration: {R.GetOptimizerIteration()}")
print(f" Metric value: {R.GetMetricValue()}")

sitk.WriteTransform(outTx, sys.argv[3])

if "SITK_NOSHOW" not in os.environ:
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(1)
    resampler.SetTransform(outTx)

    out = resampler.Execute(moving)

    simg1 = sitk.Cast(sitk.RescaleIntensity(fixed), sitk.sitkUInt8)
    simg2 = sitk.Cast(sitk.RescaleIntensity(out), sitk.sitkUInt8)
    # Halve before adding so the uint8 sum cannot exceed 255.
    cimg = sitk.Compose(simg1, simg2, simg1 // 2.0 + simg2 // 2.0)
    sitk.Show(cimg, "ImageRegistrationExhaustive Composition")
library(SimpleITK)
# Report optimizer progress; registered as the sitkIterationEvent callback.
# On iteration 0 it prints the optimizer scales once. Every call then prints
# one line "<iteration> = <metric> : [p1, p2, ...]".
commandIteration <- function(method)
{
  if (method$GetOptimizerIteration()==0) {
    cat("Scales:", method$GetOptimizerScales(), "\n")
  }
  # Collapse the position vector into a single string first. Passing it
  # straight to paste() recycles the scalar fields and emits one line per
  # parameter instead of one line per iteration.
  position <- paste0("[", paste(method$GetOptimizerPosition(), collapse = ", "), "]")
  cat(method$GetOptimizerIteration(), "=",
      method$GetMetricValue(), ":",
      position, "\n")
}
args <- commandArgs( TRUE )

if (length(args) != 3) {
  stop("3 arguments expected - fixedImageFile, movingImageFile, outputTransformFile")
}

fixed <- ReadImage(args[[1]], 'sitkFloat32')
moving <- ReadImage(args[[2]], 'sitkFloat32')

# Only 2D and 3D Euler rotation searches are configured below.
dimension <- fixed$GetDimension()
if (!(dimension %in% c(2, 3))) {
  stop("Unsupported image dimension: ", dimension, " (only 2D and 3D images are supported)")
}

R <- ImageRegistrationMethod()

R$SetMetricAsMattesMutualInformation(numberOfHistogramBins = 50)

sample_per_axis <- 12
if (dimension == 2) {
  tx <- Euler2DTransform()
  # Set the number of samples (radius) in each dimension, with a
  # default step size of 1.0
  R$SetOptimizerAsExhaustive(c(sample_per_axis%/%2,0,0))
  # Utilize the scale to set the step size for each dimension
  R$SetOptimizerScales(c(2.0*pi/sample_per_axis,1.0,1.0))
} else {
  # dimension == 3. The original `fixed.GetDimension()` is not R method-call
  # syntax and made this branch unreachable (it raised an error instead).
  tx <- Euler3DTransform()
  R$SetOptimizerAsExhaustive(c(sample_per_axis%/%2,sample_per_axis%/%2,sample_per_axis%/%4,0,0,0))
  R$SetOptimizerScales(c(2.0*pi/sample_per_axis,2.0*pi/sample_per_axis,2.0*pi/sample_per_axis,1.0,1.0,1.0))
}

# Initialize the transform with a translation and the center of
# rotation from the moments of intensity.
tx <- CenteredTransformInitializer(fixed, moving, tx)

R$SetInitialTransform(tx)

R$SetInterpolator('sitkLinear')

R$AddCommand( 'sitkIterationEvent', function() commandIteration(R) )

outTx <- R$Execute(fixed, moving)

cat("-------\n")
outTx
cat("Optimizer stop condition:", R$GetOptimizerStopConditionDescription(), '\n')
cat("Iteration:", R$GetOptimizerIteration(), '\n')
cat("Metric value:", R$GetMetricValue(), '\n')

WriteTransform(outTx, args[[3]])