Search code examples
pythonvtkitk

how to match vtk polydata and itk transforms


I need to transform (rotation for now) an itk image and vtk polydata using the same transform matrix but I am having trouble.

All the code and test data is here: https://github.com/jmerkow/vtk_itk_rotate

Here are the relavent parts:

import SimpleITK as sitk
import vtk
import numpy as np
def rotate_img(img, rotation_center=None, theta_x=0,theta_y=0, theta_z=0, translation=(0,0,0), interp=sitk.sitkLinear, pixel_type=None, default_value=None):
    if not rotation_center:
        rotation_center = np.array(img.GetOrigin())+np.array(img.GetSpacing())*np.array(img.GetSize())/2
    if default_value is None:
        default_value = img.GetPixel(0,0,0)
    pixel_type = img.GetPixelIDValue()

    rigid_euler = sitk.Euler3DTransform(rotation_center, theta_x, theta_y, theta_z, translation)
    return sitk.Resample(img, img, rigid_euler, interp, default_value, pixel_type)

def rotate_polydata(pd, rotation_center, theta_x=0,theta_y=0, theta_z=0, translation=(0,0,0)):
    rigid_euler = sitk.Euler3DTransform(rotation_center, -theta_x, -theta_y, -theta_z, translation)
    matrix = np.zeros([4,4])
    old_matrix=np.array(rigid_euler.GetMatrix()).reshape(3,3)
    matrix[:3,:3] = old_matrix
    matrix[-1,-1] = 1

    # to rotate about a center we first need to translate
    transform_t = vtk.vtkTransform()
    transform_t.Translate(-rotation_center)
    transformer_t = vtk.vtkTransformPolyDataFilter()
    transformer_t.SetTransform(transform_t)
    transformer_t.SetInputData(pd)
    transformer_t.Update()

    transform = vtk.vtkTransform()
    transform.SetMatrix(matrix.ravel())

    transformer = vtk.vtkTransformPolyDataFilter()
    transformer.SetTransform(transform)
    transformer.SetInputConnection(transformer_t.GetOutputPort())
    transformer.Update()

    # translate back
    transform_t2 = vtk.vtkTransform()
    transform_t2.Translate(rotation_center)
    transformer_t2 = vtk.vtkTransformPolyDataFilter()
    transformer_t2.SetTransform(transform_t2)
    transformer_t2.SetInputConnection(transformer.GetOutputPort())
    transformer_t2.Update()

    return transformer_t2.GetOutputDataObject(0)

datafn = 'test.mha'
polydata_file = 'test.vtp'
reader = vtk.vtkXMLPolyDataReader()
reader.SetFileName(polydata_file)
reader.Update()
pd = reader.GetOutput()

img = sitk.ReadImage(datafn)
seg = pd_to_itk_image(pd, img)
rotation_center = np.array(img.GetOrigin())+np.array(img.GetSpacing())*np.array(img.GetSize())/2
thetas = [0, 50]
thetas = [0, 50]
for theta_x in thetas:
    for theta_y in thetas:
        for theta_z in thetas:
            theta_xr = theta_x/180.*np.pi
            theta_yr = theta_y/180.*np.pi
            theta_zr = theta_z/180.*np.pi
            img_rot=rotate_img(img, theta_z=theta_zr, theta_y=theta_yr, theta_x=theta_xr)
            seg_rot=rotate_img(seg, theta_z=theta_zr, theta_y=theta_yr, theta_x=theta_xr, interp=sitk.sitkNearestNeighbor, default_value=0)
            pd_rot = rotate_polydata(pd, rotation_center, theta_z=theta_zr, theta_y=theta_yr, theta_x=theta_xr)
            seg_pd_rot = pd_to_itk_image(pd_rot, img_rot)
            mse = ((sitk.GetArrayFromImage(seg_pd_rot)-sitk.GetArrayFromImage(seg_rot))**2.).mean()

            print theta_x, theta_y, theta_z, mse

#this outputs for this particular volume:
#0 0 0 mse: 0.0
#0 0 50 mse: 50.133369863 visually about the same
#0 50 0 mse: 25.2197787166 visually about the same
#0 50 50 mse: 863.588476181 visually totally different
#50 0 0 mse: 20.4021692276 visually about the same
#50 0 50 mse: 546.699844301 visually totally different
#50 50 0 mse: 662.337975204 visually totally different
#50 50 50 mse: 339.220945537 visually totally different

This code rotates a binary volume generated from the polydata, and performs the same rotation operation on the polydata then generates a binary volume from that. I would expect these two results to be (approximately) the same, however, what I am getting is two completely different rotations if I rotate around more than one axis. This is puzzling to me since I am taking the transformation matrix from one and applying it directly to the other.

How can I setup these transforms such that the two operations perform the same transformations? And why do we end up with different results?


Solution

  • Thank you to Dženan for pointing me in the right direction.

    The answer, in this case, was simple. VTK and ITK use different row/column major formats for their matrix multiplication. So the answer was simply to transpose the matrix before putting it into the vtkTransform.

    Here is the new function.

    def rotate_polydata(pd, rotation_center, theta_x=0,theta_y=0, theta_z=0):
        #I don't want to deal with translation
        translation=(0,0,0)
        rigid_euler = sitk.Euler3DTransform(rotation_center, theta_x, theta_y, theta_z, translation)
        matrix = np.zeros([4,4])
        old_matrix=np.array(rigid_euler.GetMatrix()).reshape(3,3)
        matrix[:3,:3] = old_matrix
        matrix[-1,-1] = 1
        #ITK and VTK use different orders.
        matrix= matrix.T
    
        # to rotate about a center we first need to translate
        transform_t = vtk.vtkTransform()
        transform_t.Translate(-rotation_center)
        transformer_t = vtk.vtkTransformPolyDataFilter()
        transformer_t.SetTransform(transform_t)
        transformer_t.SetInputData(pd)
        transformer_t.Update()
    
        transform = vtk.vtkTransform()
        transform.SetMatrix(matrix.ravel())
        transform.Translate(translation)
        transform.PostMultiply()
    
        transformer = vtk.vtkTransformPolyDataFilter()
        transformer.SetTransform(transform)
        transformer.SetInputConnection(transformer_t.GetOutputPort())
        transformer.Update()
    
        # translate back
        transform_t2 = vtk.vtkTransform()
        transform_t2.Translate(rotation_center)
        transformer_t2 = vtk.vtkTransformPolyDataFilter()
        transformer_t2.SetTransform(transform_t2)
        transformer_t2.SetInputConnection(transformer.GetOutputPort())
        transformer_t2.Update()
    
        return transformer_t2.GetOutputDataObject(0)