Tags: python, python-2.7, numpy, affinetransform, scikit-image

Piecewise Affine Transform+warp output looks strange


I have an image I am trying to warp using skimage.transform.PiecewiseAffineTransform and skimage.transform.warp. I have a set of control points (true) mapped to a new set of control points (ideal), but the warp is not returning what I expect.

In this example, I have a simple gradient of wavelengths that I am trying to 'straighten out' into columns. (You might ask why I am finding contours and interpolating; that is because I am actually applying this code to a more complex use case. I just wanted to reproduce all the code for this simple example, which results in the same strange output.)

Is there any reason why my output image just shows the input image warped into a square and inset? I'm using Python 2.7.12 and matplotlib 1.5.1. Here is the code:

import matplotlib.pyplot as plt
import numpy as np
from skimage import measure, transform

true = np.array([range(i,i+10) for i in range(20)])
ideal = np.array([range(10)]*20)

# Find contours of ideal and true images and create list of control points for warp
true_control_pts = []
ideal_control_pts = []

for lam in ideal[0]:
    try:
        # Get the isowavelength contour in the true and ideal images
        tc = measure.find_contours(true, lam)[0]
        ic = measure.find_contours(ideal, lam)[0]
        nc = np.ones(ic.shape)

        # Use the y coordinates of the ideal contour
        nc[:, 0] = ic[:, 0]

        # Interpolate true contour onto ideal contour y axis so there are the same number of points
        nc[:, 1] = np.interp(ic[:, 0], tc[:, 0], tc[:, 1])

        # Add the control points to the appropriate list
        true_control_pts.append(nc.tolist())
        ideal_control_pts.append(ic.tolist())

    except (IndexError,AttributeError):
        pass

true_control_pts = np.array(true_control_pts)
ideal_control_pts = np.array(ideal_control_pts)

length = len(true_control_pts.flatten())//2
true_control_pts = true_control_pts.reshape(length,2)
ideal_control_pts = ideal_control_pts.reshape(length,2)

# Plot the original image
image = np.array([range(i,i+10) for i in range(20)]).astype(np.int32)
plt.figure()
plt.imshow(image, origin='lower', interpolation='none')
plt.title('Input image')

# Warp the actual image given the transformation between the true and ideal wavelength maps
tform = transform.PiecewiseAffineTransform()
tform.estimate(true_control_pts, ideal_control_pts)
out = transform.warp(image, tform)

# Plot the warped image!
fig, ax = plt.subplots()
ax.imshow(out, origin='lower', interpolation='none')
plt.title('Should be parallel lines')

The output for this looks like:

[Image: warped output]

Any assistance would be greatly appreciated!


Solution

  • I'll give you what I have. I don't think I'm going to be able to nail this down with the test data given. Mapping a 45 degree angle to a straight line in such a small image means that there's a lot of motion that needs to happen, and not a lot of data to base it on. I did find a few specific errors that are fixed in the code below, marked /* */ (since that's not something you usually see in a Python file, that marker should stand out :) ).

    Please try this code on your real data and let me know if it works! With this input dataset, there are some nonzero outputs.

    The big issues were:

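    • the x-coordinates (xp) passed to np.interp need to be sorted in increasing order; numpy does not check this and silently returns nonsense otherwise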
    • many of the control points were nonsensical since there isn't enough data in this dataset (e.g., mapping large spans of a column to a single point)
    • the range of the values in the image was off (so your original output was actually nearly zero everywhere; the color patches were values around 1e-9).
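To see the first issue in isolation, here is a small numpy-only sketch with a made-up contour (the arrays are illustrative, not taken from the question). The contour is traced with its row coordinate decreasing, which is exactly the case where np.interp misbehaves without complaint; sorting with argsort, as the code below does, fixes it.

```python
import numpy as np

# A toy "contour" as (row, col) pairs, traced from the bottom row upward,
# so the row coordinate (column 0) is decreasing. Here col = 2 * row.
tc = np.array([[4.0, 8.0], [3.0, 6.0], [2.0, 4.0], [1.0, 2.0], [0.0, 0.0]])

# Rows at which we want the interpolated column position.
query_rows = np.array([0.5, 1.5, 2.5, 3.5])

# np.interp assumes xp (tc[:, 0]) is increasing and does not check it,
# so the unsorted contour gives meaningless values.
unsorted_cols = np.interp(query_rows, tc[:, 0], tc[:, 1])

# Sort the contour points by row first, then interpolate.
tc_sorted = tc[tc[:, 0].argsort()]
sorted_cols = np.interp(query_rows, tc_sorted[:, 0], tc_sorted[:, 1])

# sorted_cols holds 1, 3, 5, 7 (that is, 2 * query_rows)
print(unsorted_cols)
print(sorted_cols)
```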

    The most important thing I added, for the sake of your future coding, is a 3D plot showing how the "true" control points map to the "ideal" control points. That gives you a debug tool to show you whether your control-point mapping is as you expect. That plot is what led me to the interp problem.
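If you want a numeric check alongside the 3D plot, a sketch like the following summarizes the arrows from "before" to "after" points and counts how many distinct locations appear on each side. A much smaller count on one side is one way the many-to-one mapping mentioned above (large spans collapsing onto a single point) would show up. summarize_mapping, before, and after are hypothetical names, and the sample arrays are made up.

```python
import numpy as np

def summarize_mapping(before, after, decimals=6):
    """Print simple statistics about a control-point mapping.

    before, after: (N, 2) arrays of matching control points (in the
    answer's code these would be true_control_pts and ideal_control_pts).
    """
    before = np.asarray(before, dtype=float)
    after = np.asarray(after, dtype=float)
    if before.shape != after.shape or before.ndim != 2 or before.shape[1] != 2:
        raise ValueError("expected two (N, 2) arrays of the same shape")

    # Length of the arrow from each "before" point to its "after" point.
    disp = np.linalg.norm(after - before, axis=1)

    # Distinct locations on each side (rounded to ignore float noise).
    n_before = len(np.unique(np.round(before, decimals), axis=0))
    n_after = len(np.unique(np.round(after, decimals), axis=0))

    print("points:", len(before))
    print("distinct before/after:", n_before, n_after)
    print("displacement min/mean/max: %.3f / %.3f / %.3f"
          % (disp.min(), disp.mean(), disp.max()))
    return disp, n_before, n_after

# Made-up example: two "before" points coincide but map to different places.
before = np.array([[0, 0], [1, 0], [1, 0], [2, 0]])
after = np.array([[0, 0], [0, 1], [0, 2], [0, 3]])
disp, n_before, n_after = summarize_mapping(before, after)
```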

    By the way, please use names like "before" and "after" instead of ideal and true :) . Trying to remember which is which tripped me up at least once.

    First try

    import pdb
    import matplotlib.pyplot as plt
    import numpy as np
    from skimage import measure, transform, img_as_float
    
    from mpl_toolkits.mplot3d import Axes3D # /**/
    
    #/**/
    # From https://stackoverflow.com/a/14491059/2877364 by
    # https://stackoverflow.com/users/1355221/dansalmo
    def flatten_list(L):
        for item in L:
            try:
                for i in flatten_list(item): yield i
            except TypeError:
                yield item
    #end flatten_list
    
    true_input = np.array([range(i,i+10) for i in range(20)])  # /** != True **/
    ideal = np.array([range(10)]*20)
    
    #pdb.set_trace()
    # Find contours of ideal and true_input images and create list of control points for warp
    true_control_pts = []
    ideal_control_pts = []
    OLD_true=[]     # /**/ for debugging
    OLD_ideal=[]
    
    for lam in [x+0.5 for x in ideal[0]]:   # I tried the 0.5 offset just to see,
        try:                                # but it didn't make much difference
    
            # Get the isowavelength contour in the true_input and ideal images
            tc = measure.find_contours(true_input, lam)[0]
            ic = measure.find_contours(ideal, lam)[0]
            nc = np.zeros(ic.shape) # /** don't need ones() **/
    
            # Use the y (row) coordinates of the ideal contour /** find_contours returns (row, col) pairs, so column 0 is y **/
            nc[:, 0] = ic[:, 0]
    
            # Interpolate true contour onto ideal contour y axis so there are the same number of points
    
            # /** Have to sort first - https://docs.scipy.org/doc/numpy/reference/generated/numpy.interp.html#numpy-interp **/
            tc_sorted = tc[tc[:,0].argsort()]
                # /** Thanks to https://stackoverflow.com/a/2828121/2877364 by
                # https://stackoverflow.com/users/208339/steve-tjoa **/
    
            nc[:, 1] = np.interp(ic[:, 0], tc_sorted[:, 0], tc_sorted[:, 1],
                left=np.nan, right=np.nan)
                # /** nan: If the interpolation is out of range, we're not getting
                #     useful data.  Therefore, flag it with a nan. **/
    
            # /** Filter out the NaNs **/
            # Thanks to https://stackoverflow.com/a/11453235/2877364 by
            # https://stackoverflow.com/users/449449/eumiro
            #pdb.set_trace()
            indices = ~np.isnan(nc).any(axis=1)
            nc_nonan = nc[indices]
            ic_nonan = ic[indices]
    
            # Add the control points to the appropriate list.
            # /** Flattening here since otherwise I wound up with dtype=object
            #     in the numpy arrays. **/
            true_control_pts.append(nc_nonan.flatten().tolist())
            ideal_control_pts.append(ic_nonan.flatten().tolist())
    
            OLD_true.append(nc)     # /** for debug **/
            OLD_ideal.append(ic)
    
        except (IndexError,AttributeError):
            pass
    
    #pdb.set_trace()
    # /** Make vectors of all the control points. **/
    true_flat = list(flatten_list(true_control_pts))
    ideal_flat = list(flatten_list(ideal_control_pts))
    true_control_pts = np.array(true_flat)
    ideal_control_pts = np.array(ideal_flat)
    
    # Make the vectors 2d
    length = len(true_control_pts)//2   # /** integer division, needed for reshape() and range() on Python 3 **/
    true_control_pts = true_control_pts.reshape(length,2)
    ideal_control_pts = ideal_control_pts.reshape(length,2)
    
    #pdb.set_trace()
    
    # Plot the original image
    image = np.array([range(i,i+10) for i in range(20)]) / 30.0 # /**.astype(np.int32)**/
        # /** You don't want int32 images!  See
        #     http://scikit-image.org/docs/dev/user_guide/data_types.html .
        #     Manually rescale the image to [0.0,1.0] by dividing by 30. **/
    image_float = img_as_float(image) #/** make sure skimage is happy **/
    fig = plt.figure()
    plt.imshow(image_float, origin='lower', interpolation='none')
    plt.title('Input image')
    fig.show()  # /** I needed this on my test system **/
    
    # Warp the actual image given the transformation between the true and ideal wavelength maps
    tform = transform.PiecewiseAffineTransform()
    tform.estimate(true_control_pts, ideal_control_pts)
    out = transform.warp(image, tform)
        # /** since we started with float, and this is float, too, the two are
        #     comparable. **/
    
    #pdb.set_trace()
    
    # Plot the warped image!
    fig, ax = plt.subplots()
    ax.imshow(out, origin='lower', interpolation='none')    # /**note: float**/
    plt.title('Should be parallel lines')
    fig.show()
    
    # /** Show the control points.
    #     The z=0 plane will be the "true" control points (before), and the
    #     z=1 plane will be the "ideal" control points (after). **/
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    fig.show()
    
    for rowidx in range(length):
        ax.plot([true_control_pts[rowidx,0], ideal_control_pts[rowidx,0]],
                [true_control_pts[rowidx,1], ideal_control_pts[rowidx,1]],
                [0,1])
    
    input() # /** because I was running from the command line **/
    
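On the int32 issue: scikit-image's data-type guide (linked in the code comment above) describes integer images being rescaled by their dtype's range when converted to float. Assuming warp applies that conversion to the int32 input, the question's values of 0 to 28 end up on the order of 1e-8, which matches the near-zero output described above. The sketch below only imitates that scaling with plain numpy (dividing by np.iinfo(np.int32).max is an approximation, not scikit-image's exact code path) and compares it with the manual /30.0 rescaling.

```python
import numpy as np

# The question's test image: values 0..28 stored as int32.
image_int = np.array([range(i, i + 10) for i in range(20)]).astype(np.int32)

# Assumption: imitate skimage's int -> float conversion by dividing by the
# int32 maximum (about 2.1e9).
int32_max = np.iinfo(np.int32).max
as_float_like_skimage = image_int / float(int32_max)

# The answer's fix: rescale manually into [0, 1] before warping.
image_rescaled = image_int / 30.0

print(as_float_like_skimage.max())  # about 1.3e-08
print(image_rescaled.max())         # 28 / 30, about 0.933
```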

    Second try

    Getting closer:

    [Image: warped output from the second try]

    And here's a view of the control-point mapping looking more promising:

    [Image: 3D plot of the control-point mapping]

    You can see that it's trying to whirl the image a bit, which is what I would expect from this dataset.

    Code

    import pdb
    import matplotlib.pyplot as plt
    import numpy as np
    from skimage import measure, transform, img_as_float
    
    from mpl_toolkits.mplot3d import Axes3D # /**/
    
    #/**/
    # From https://stackoverflow.com/a/14491059/2877364 by
    # https://stackoverflow.com/users/1355221/dansalmo
    def flatten_list(L):
        for item in L:
            try:
                for i in flatten_list(item): yield i
            except TypeError:
                yield item
    #end flatten_list
    
    #/**/
    # Modified from https://stackoverflow.com/a/19122075/2877364 by
    # https://stackoverflow.com/users/2588210/christian-k
    def equispace(data, npts):
        x,y = data.T
        xd =np.diff(x)
        yd = np.diff(y)
        dist = np.sqrt(xd**2+yd**2)
        u = np.cumsum(dist)
        u = np.hstack([[0],u])
    
        t = np.linspace(0,u.max(),npts)
        xn = np.interp(t, u, x)
        yn = np.interp(t, u, y)
        return np.column_stack((xn, yn))
    
    true_input = np.array([range(i,i+10) for i in range(20)])  # /** != True **/
    ideal = np.array([range(10)]*20)
    
    #pdb.set_trace()
    # Find contours of ideal and true_input images and create list of control points for warp
    true_control_pts = []
    ideal_control_pts = []
    OLD_true=[]     # /**/ for debugging
    OLD_ideal=[]
    
    for lam in [x+0.5 for x in ideal[0]]:   # I tried the 0.5 offset just to see,
        try:                                # but it didn't make much difference
    
            # Get the isowavelength contour in the true_input and ideal images
            tc = measure.find_contours(true_input, lam)[0]
                # /** So this might not have very many numbers in it. **/
            ic = measure.find_contours(ideal, lam)[0]
                # /** CAUTION: this is assuming the contours are going the same
                #       direction.  If not, you'll need to make it so. **/
            #nc = np.zeros(ic.shape) # /** don't need ones() **/
    
            # /** We just want to find points on _tc_ to match _ic_.  That's
            #       interpolation _within_ a curve. **/
            #pdb.set_trace()
            nc_by_t = equispace(tc,ic.shape[0])
            ic_by_t = equispace(ic,ic.shape[0])
    
    
            ### /** Not this **/
            ## Use the y /** X? **/ coordinates of the ideal contour
            #nc[:, 0] = ic[:, 0]
            #
            ## Interpolate true contour onto ideal contour y axis so there are the same number of points
            #
            ## /** Have to sort first - https://docs.scipy.org/doc/numpy/reference/generated/numpy.interp.html#numpy-interp **/
            #tc_sorted = tc[tc[:,0].argsort()]
            #    # /** Thanks to https://stackoverflow.com/a/2828121/2877364 by
            #    # https://stackoverflow.com/users/208339/steve-tjoa **/
            #
            #nc[:, 1] = np.interp(ic[:, 0], tc_sorted[:, 0], tc_sorted[:, 1],
            #    left=np.nan, right=np.nan)
            #    # /** nan: If the interpolation is out of range, we're not getting
            #    #     useful data.  Therefore, flag it with a nan. **/
    
            # /** Filter out the NaNs **/
            # Thanks to https://stackoverflow.com/a/11453235/2877364 by
            # https://stackoverflow.com/users/449449/eumiro
            #pdb.set_trace()
            #indices = ~np.isnan(nc).any(axis=1)
            #nc_nonan = nc[indices]
            #ic_nonan = ic[indices]
            #
    
            # Add the control points to the appropriate list.
            ## /** Flattening here since otherwise I wound up with dtype=object
            ##     in the numpy arrays. **/
            #true_control_pts.append(nc_nonan.flatten().tolist())
            #ideal_control_pts.append(ic_nonan.flatten().tolist())
    
            #OLD_true.append(nc)     # /** for debug **/
            #OLD_ideal.append(ic)
    
            true_control_pts.append(nc_by_t)
            ideal_control_pts.append(ic_by_t)
    
        except (IndexError,AttributeError):
            pass
    
    #pdb.set_trace()
    # /** Make vectors of all the control points. **/
    true_flat = list(flatten_list(true_control_pts))
    ideal_flat = list(flatten_list(ideal_control_pts))
    true_control_pts = np.array(true_flat)
    ideal_control_pts = np.array(ideal_flat)
    
    # Make the vectors 2d
    length = len(true_control_pts)//2   # /** integer division, needed for reshape() and range() on Python 3 **/
    true_control_pts = true_control_pts.reshape(length,2)
    ideal_control_pts = ideal_control_pts.reshape(length,2)
    
    #pdb.set_trace()
    
    # Plot the original image
    image = np.array([range(i,i+10) for i in range(20)]) / 30.0 # /**.astype(np.int32)**/
        # /** You don't want int32 images!  See
        #     http://scikit-image.org/docs/dev/user_guide/data_types.html .
        #     Manually rescale the image to [0.0,1.0] by dividing by 30. **/
    image_float = img_as_float(image) #/** make sure skimage is happy **/
    fig = plt.figure()
    plt.imshow(image_float, origin='lower', interpolation='none')
    plt.title('Input image')
    fig.show()  # /** I needed this on my test system **/
    
    # Warp the actual image given the transformation between the true and ideal wavelength maps
    tform = transform.PiecewiseAffineTransform()
    tform.estimate(true_control_pts, ideal_control_pts)
    out = transform.warp(image, tform)
        # /** since we started with float, and this is float, too, the two are
        #     comparable. **/
    
    #pdb.set_trace()
    
    # Plot the warped image!
    fig, ax = plt.subplots()
    ax.imshow(out, origin='lower', interpolation='none')    # /**note: float**/
    plt.title('Should be parallel lines')
    fig.show()
    
    # /** Show the control points.
    #     The z=0 plane will be the "true" control points (before), and the
    #     z=1 plane will be the "ideal" control points (after). **/
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    fig.show()
    
    for rowidx in range(length):
        ax.plot([true_control_pts[rowidx,0], ideal_control_pts[rowidx,0]],
                [true_control_pts[rowidx,1], ideal_control_pts[rowidx,1]],
                [0,1])
    
    input() # /** because I was running from the command line **/
    
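Because each contour contributes a different number of points, the code above collects them in a list and then goes through flatten_list and a reshape. A shorter alternative, sketched here with made-up per-contour arrays, is np.concatenate along the first axis, which yields one (N, 2) array directly. If you keep the reshape, note that it needs integer division (//): under Python 3, len(...)/2 is a float, which range() rejects and recent numpy versions reject in reshape() as well.

```python
import numpy as np

# Made-up per-contour control points: each entry is an (n_i, 2) array,
# and the n_i differ from contour to contour.
per_contour = [
    np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]),
    np.array([[0.0, 2.0], [1.0, 2.0]]),
    np.array([[0.0, 3.0], [1.0, 3.0], [2.0, 3.0], [3.0, 3.0]]),
]

# Stack all contours into one (N, 2) array of control points.
control_pts = np.concatenate(per_contour, axis=0)

# If you go the flatten/reshape route instead, use '//' so that the
# result is an int on both Python 2 and Python 3.
length = control_pts.size // 2

print(control_pts.shape)  # (9, 2)
print(length)             # 9
```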

    Explanation

    Think about the interpolation differently: the X and Y coordinates aren't actually what you want to map, are they? I think what you want to map is distance along the contour, so that the diagonal contour gets stretched out to be a vertical contour. That is what these lines do:

    nc_by_t = equispace(tc,ic.shape[0])
    ic_by_t = equispace(ic,ic.shape[0])
    

    We place ic.shape[0] equally spaced points along each contour, and then map those points to each other. equispace is modified from https://stackoverflow.com/a/19122075/2877364 (credited in the code comment). This stretches the shorter contour to fit the longer one, whichever that is, using as many points as ic has vertices. You can actually use any number of points with this technique; I just tested 100 points and got substantially the same result.
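Here is the same arc-length idea as equispace, written as a standalone function (resample_by_arclength is a hypothetical name) and applied to a made-up diagonal "true" contour and vertical "ideal" contour. Resampling both to the same number of points makes the k-th point on one correspond to the k-th point on the other, at the same fraction of each contour's length.

```python
import numpy as np

def resample_by_arclength(curve, npts):
    """Return npts points spaced evenly along a polyline.

    curve: (M, 2) array of vertices, e.g. one contour from find_contours.
    Parametrize the curve by cumulative arc length u, then interpolate
    each coordinate at evenly spaced values of u.
    """
    curve = np.asarray(curve, dtype=float)
    seg = np.sqrt((np.diff(curve, axis=0) ** 2).sum(axis=1))  # segment lengths
    u = np.concatenate(([0.0], np.cumsum(seg)))               # arc length at each vertex
    t = np.linspace(0.0, u[-1], npts)                          # evenly spaced targets
    # u increases along the curve (assuming no repeated consecutive
    # vertices), so it is a valid xp for np.interp.
    return np.column_stack([np.interp(t, u, curve[:, k]) for k in range(2)])

# A diagonal contour and a vertical contour with different lengths and
# vertex counts, each resampled to 7 points so they can be paired up.
diagonal = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
vertical = np.array([[0.0, 5.0], [3.0, 5.0]])
src = resample_by_arclength(diagonal, 7)
dst = resample_by_arclength(vertical, 7)
# src[k] is (k/2, k/2) and dst[k] is (k/2, 5) for k = 0..6
```

Pairing src[k] with dst[k] is what "stretches the diagonal contour out to be a vertical contour" in the control-point lists.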

    Again, test this on your actual data. Think about what exactly your reference is for interpolation. Is it really the X or Y coordinate? Is it distance along the contour? A percentage of the contour's length (as above)?
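One coordinate question worth checking in any version of this code: skimage.measure.find_contours returns points as (row, column), whereas scikit-image's geometric transforms (including PiecewiseAffineTransform) and warp() work in (x, y) = (column, row) order. warp() also treats the transform as the inverse map, from output coordinates to input coordinates. The helper below (rc_to_xy, a hypothetical name) just swaps the two columns; the commented lines show, as an untested assumption about your pipeline, how the estimate would look if the goal is for each "true" point in the input to land at its "ideal" location in the output.

```python
import numpy as np

def rc_to_xy(points_rc):
    """Swap (row, col) points, as returned by find_contours, into the
    (x, y) = (col, row) order used by scikit-image transforms."""
    points_rc = np.asarray(points_rc, dtype=float)
    return points_rc[:, ::-1]

# Two made-up contour points: (row 2, col 7) and (row 3, col 8).
contour_rc = np.array([[2.0, 7.0], [3.0, 8.0]])
contour_xy = rc_to_xy(contour_rc)   # [[7, 2], [8, 3]]

# Untested sketch of use with the answer's arrays (not run here):
# from skimage import transform
# tform = transform.PiecewiseAffineTransform()
# # src = output ("ideal") coordinates, dst = input ("true") coordinates,
# # because warp() uses tform to map output pixels back into the input.
# tform.estimate(rc_to_xy(ideal_control_pts), rc_to_xy(true_control_pts))
# out = transform.warp(image, tform)
```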

    OK, one more idea

    For this particular test case, would more data help? Yes!

    [Image: warped output using larger images for the control points]

    I used larger images to determine the control points, then mapped a smaller image in the center of that region.
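Centering the small image inside the larger control-point grid is done in the code that follows with two np.concatenate calls. As a sketch, np.pad with a constant of zero does the same thing in one call, assuming the same 20x10 input and 30x20 target as that code.

```python
import numpy as np

# The small test image from the question, rescaled as in the answer.
image = np.array([range(i, i + 10) for i in range(20)]) / 30.0   # shape (20, 10)

# Pad 5 rows above/below and 5 columns left/right with zeros so the small
# image sits in the center of the larger (30, 20) grid.
padded = np.pad(image, ((5, 5), (5, 5)), mode='constant', constant_values=0)

# The concatenate version, for comparison.
via_concat = np.concatenate((np.zeros((20, 5)), image, np.zeros((20, 5))), 1)
via_concat = np.concatenate((np.zeros((5, 20)), via_concat, np.zeros((5, 20))), 0)

print(padded.shape)  # (30, 20)
```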

    Code

    (which is a total mess by this point; see the =========== markers)

    import pdb
    import matplotlib.pyplot as plt
    import numpy as np
    from skimage import measure, transform, img_as_float
    
    from mpl_toolkits.mplot3d import Axes3D # /**/
    
    #/**/
    def test_figure(data,title):
        image_float = img_as_float(data) #/** make sure skimage is happy **/
        fig = plt.figure()
        plt.imshow(image_float, origin='lower', interpolation='none')
        plt.title(title)
        fig.show()
    
    #/**/
    # From https://stackoverflow.com/a/14491059/2877364 by
    # https://stackoverflow.com/users/1355221/dansalmo
    def flatten_list(L):
        for item in L:
            try:
                for i in flatten_list(item): yield i
            except TypeError:
                yield item
    #end flatten_list
    
    #/**/
    # Modified from https://stackoverflow.com/a/19122075/2877364 by
    # https://stackoverflow.com/users/2588210/christian-k
    def equispace(data, npts):
        x,y = data.T
        xd =np.diff(x)
        yd = np.diff(y)
        dist = np.sqrt(xd**2+yd**2)
        u = np.cumsum(dist)
        u = np.hstack([[0],u])
    
        t = np.linspace(0,u.max(),npts)
        xn = np.interp(t, u, x)
        yn = np.interp(t, u, y)
        return np.column_stack((xn, yn))
    
    # ======================  BIGGER
    true_input = np.array([range(i,i+20) for i in range(30)])  # /** != True **/
    ideal = np.array([range(20)]*30)
    # ======================    
    test_figure(true_input / 50.0, 'true_input')
    test_figure(ideal / 20.0, 'ideal')
    
    #pdb.set_trace()
    # Find contours of ideal and true_input images and create list of control points for warp
    true_control_pts = []
    ideal_control_pts = []
    OLD_true=[]     # /**/ for debugging
    OLD_ideal=[]
    
    for lam in [x+0.5 for x in ideal[0]]:   # I tried the 0.5 offset just to see,
        try:                                # but it didn't make much difference
    
            # Get the isowavelength contour in the true_input and ideal images
            tc = measure.find_contours(true_input, lam)[0]
                # /** So this might not have very many numbers in it. **/
            ic = measure.find_contours(ideal, lam)[0]
                # /** CAUTION: this is assuming the contours are going the same
                #       direction.  If not, you'll need to make it so. **/
            #nc = np.zeros(ic.shape) # /** don't need ones() **/
    
            # /** We just want to find points on _tc_ to match _ic_.  That's
            #       interpolation _within_ a curve. **/
            #pdb.set_trace()
            nc_by_t = equispace(tc,ic.shape[0])
            ic_by_t = equispace(ic,ic.shape[0])
    
    
            ### /** Not this **/
            ## Use the y /** X? **/ coordinates of the ideal contour
            #nc[:, 0] = ic[:, 0]
            #
            ## Interpolate true contour onto ideal contour y axis so there are the same number of points
            #
            ## /** Have to sort first - https://docs.scipy.org/doc/numpy/reference/generated/numpy.interp.html#numpy-interp **/
            #tc_sorted = tc[tc[:,0].argsort()]
            #    # /** Thanks to https://stackoverflow.com/a/2828121/2877364 by
            #    # https://stackoverflow.com/users/208339/steve-tjoa **/
            #
            #nc[:, 1] = np.interp(ic[:, 0], tc_sorted[:, 0], tc_sorted[:, 1],
            #    left=np.nan, right=np.nan)
            #    # /** nan: If the interpolation is out of range, we're not getting
            #    #     useful data.  Therefore, flag it with a nan. **/
    
            # /** Filter out the NaNs **/
            # Thanks to https://stackoverflow.com/a/11453235/2877364 by
            # https://stackoverflow.com/users/449449/eumiro
            #pdb.set_trace()
            #indices = ~np.isnan(nc).any(axis=1)
            #nc_nonan = nc[indices]
            #ic_nonan = ic[indices]
            #
    
            # Add the control points to the appropriate list.
            ## /** Flattening here since otherwise I wound up with dtype=object
            ##     in the numpy arrays. **/
            #true_control_pts.append(nc_nonan.flatten().tolist())
            #ideal_control_pts.append(ic_nonan.flatten().tolist())
    
            #OLD_true.append(nc)     # /** for debug **/
            #OLD_ideal.append(ic)
    
            true_control_pts.append(nc_by_t)
            ideal_control_pts.append(ic_by_t)
    
        except (IndexError,AttributeError):
            pass
    
    #pdb.set_trace()
    # /** Make vectors of all the control points. **/
    true_flat = list(flatten_list(true_control_pts))
    ideal_flat = list(flatten_list(ideal_control_pts))
    true_control_pts = np.array(true_flat)
    ideal_control_pts = np.array(ideal_flat)
    
    # Make the vectors 2d
    length = len(true_control_pts)//2   # /** integer division, needed for reshape() and range() on Python 3 **/
    true_control_pts = true_control_pts.reshape(length,2)
    ideal_control_pts = ideal_control_pts.reshape(length,2)
    
    #pdb.set_trace()
    
    # Plot the original image
    image = np.array([range(i,i+10) for i in range(20)]) / 30.0 # /**.astype(np.int32)**/
        # /** You don't want int32 images!  See
        #     http://scikit-image.org/docs/dev/user_guide/data_types.html .
        #     Manually rescale the image to [0.0,1.0] by dividing by 30. **/
    
    # ======================    
    # /** Pad from 10x20 to 20x30 just for grins **/
    #pdb.set_trace()
    image = np.concatenate( (np.zeros((20,5)),image,np.zeros((20,5))), 1)
        # now it's 20x20
    image = np.concatenate( (np.zeros((5,20)),image,np.zeros((5,20))), 0)
    # ======================    
    
    #Plot it
    image_float = img_as_float(image) #/** make sure skimage is happy **/
    fig = plt.figure()
    plt.imshow(image_float, origin='lower', interpolation='none')
    plt.title('Input image')
    fig.show()  # /** I needed this on my test system **/
    
    # Warp the actual image given the transformation between the true and ideal wavelength maps
    tform = transform.PiecewiseAffineTransform()
    tform.estimate(true_control_pts, ideal_control_pts)
    out = transform.warp(image, tform)
        # /** since we started with float, and this is float, too, the two are
        #     comparable. **/
    
    #pdb.set_trace()
    
    # Plot the warped image!
    fig, ax = plt.subplots()
    ax.imshow(out, origin='lower', interpolation='none')    # /**note: float**/
    plt.title('Should be parallel lines')
    fig.show()
    
    # /** Show the control points.
    #     The z=0 plane will be the "true" control points (before), and the
    #     z=1 plane will be the "ideal" control points (after). **/
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    fig.show()
    
    for rowidx in range(length):
        ax.plot([true_control_pts[rowidx,0], ideal_control_pts[rowidx,0]],
                [true_control_pts[rowidx,1], ideal_control_pts[rowidx,1]],
                [0,1])
    
    input() # /** because I was running from the command line **/