Image registration with SimpleITK: how to get transformed image


I recently learned image registration using SimpleITK (sitk).

My task is to subtract two CT images that were acquired at two different time points.

First, I need to register these images. I did that using the code from this notebook. Below is what I implemented from it (I used the code blindly; I don't fully understand the sitk documentation). My understanding of the code below is that final_transform contains the transformation needed to register these images. Can someone explain:

  1. what does final_transform actually represent?
  2. how can I use final_transform so that I can do pixel-wise operations between two CT images, i.e. image subtraction?
  3. how can I set one of the images as the reference image, and then transform the other so that it aligns with the reference?
import numpy as np
import SimpleITK as sitk

reader = sitk.ImageSeriesReader()

dir_name_1 = "Directory of 1st CT image"
dir_name_2 = "Directory of 2nd CT image"

fixed_image = sitk.ReadImage(reader.GetGDCMSeriesFileNames(dir_name_1), sitk.sitkFloat32)
moving_image = sitk.ReadImage(reader.GetGDCMSeriesFileNames(dir_name_2), sitk.sitkFloat32)

# Evaluate the similarity metric for several candidate initial orientations and keep the best.
# Dictionary of the orientations we will try. We omit the identity (x=0, y=0, z=0) because it is
# always evaluated. This set of rotations is arbitrary (only three are listed here). For complete
# grid coverage we would naively need 64 entries (0, pi/2, pi, 1.5pi for each angle), but there
# are only 24 unique rotation matrices defined by these parameter value combinations.
all_orientations = {
    "x=0, y=0, z=180": (0.0, 0.0, np.pi),
    "x=0, y=180, z=0": (0.0, np.pi, 0.0),
    "x=0, y=180, z=180": (0.0, np.pi, np.pi),
}

# Registration framework setup.
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)
registration_method.SetInterpolator(sitk.sitkLinear)

# Evaluate the similarity metric using the rotation parameter space sampling, translation remains the same for all.
initial_transform = sitk.Euler3DTransform(
    sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
)
registration_method.SetInitialTransform(initial_transform, inPlace=False)
best_orientation = (0.0, 0.0, 0.0)
best_similarity_value = registration_method.MetricEvaluate(fixed_image, moving_image)

# Iterate over all other rotation parameter settings.
for key, orientation in all_orientations.items():
    initial_transform.SetRotation(*orientation)
    registration_method.SetInitialTransform(initial_transform)
    current_similarity_value = registration_method.MetricEvaluate(
        fixed_image, moving_image
    )
    if current_similarity_value < best_similarity_value:
        best_similarity_value = current_similarity_value
        best_orientation = orientation
print("best orientation is: " + str(best_orientation))


from multiprocessing.pool import ThreadPool
from functools import partial


# This function evaluates the metric value in a thread safe manner
def evaluate_metric(current_rotation, tx, f_image, m_image):
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    registration_method.SetInterpolator(sitk.sitkLinear)
    current_transform = sitk.Euler3DTransform(tx)
    current_transform.SetRotation(*current_rotation)
    registration_method.SetInitialTransform(current_transform)
    res = registration_method.MetricEvaluate(f_image, m_image)
    return res


orientations_list = [(0, 0, 0)] + list(all_orientations.values())
# Use the pool as a context manager so its worker threads are cleaned up afterwards.
with ThreadPool(len(all_orientations) + 1) as p:
    all_metric_values = p.map(
        partial(
            evaluate_metric, tx=initial_transform, f_image=fixed_image, m_image=moving_image
        ),
        orientations_list,
    )
best_orientation = orientations_list[np.argmin(all_metric_values)]
print("best orientation is: " + str(best_orientation))


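initial_transform.SetRotation(*best_orientation)
# multires_registration() is not a SimpleITK function; presumably it is a helper defined in the
# notebook this code was copied from (not shown here). Its first return value is the optimized transform.
final_transform, _ = multires_registration(fixed_image, moving_image, initial_transform)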
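The comment at the top of the orientation search claims that the 64 combinations of 0, pi/2, pi and 1.5pi per axis produce only 24 distinct rotation matrices. That claim can be checked numerically. This is a minimal sketch in plain NumPy rather than SimpleITK, and it assumes the three rotations are composed as Rz @ Rx @ Ry, which we believe matches Euler3DTransform's default convention (ComputeZYX off); treat that convention as an assumption.

```python
import itertools

import numpy as np


def rot_x(a):
    """Rotation by angle a (radians) about the x axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rot_y(a):
    """Rotation by angle a (radians) about the y axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rot_z(a):
    """Rotation by angle a (radians) about the z axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


angles = [0.0, np.pi / 2, np.pi, 1.5 * np.pi]
unique = set()
combinations = list(itertools.product(angles, repeat=3))
for ax, ay, az in combinations:
    # Assumed composition order: R = Rz @ Rx @ Ry
    r = rot_z(az) @ rot_x(ax) @ rot_y(ay)
    # Entries are exactly 0 or +/-1 up to floating-point noise, so round before hashing
    unique.add(tuple(np.round(r, 6).ravel().tolist()))

print("parameter combinations:", len(combinations))
print("unique rotation matrices:", len(unique))
```

Only those 24 matrices are genuinely different starting orientations, which is why a complete search never needs all 64 entries.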

I am not sure how to proceed further with final_transform. How can I use it to subtract the two CT images?


Solution

  • The registration method returns a transform that maps points from the physical space of the fixed image to the physical space of the moving image. Presumably you want to resample the moving image so that it matches the fixed image. To do so, use SimpleITK's Resample function.

    To resample the moving image, you would do something like this:

    resampled_moving = sitk.Resample(moving_image, fixed_image, final_transform)
    

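    That gives you a version of the moving image resampled onto the fixed image's grid (same size, origin, spacing and direction), with its content aligned by final_transform. In other words, the fixed image acts as the reference and the moving image is transformed onto it. Because both images then share the same geometry, you can subtract them voxel by voxel.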