Tags: python, scipy, linear-algebra

simple scipy.spatial.transform.Rotation.align_vectors testing script doesn't return the correct rotation


I tried to find the rotation that aligns two planes with each other once their centers coincide. That failed, so I came up with this very simple example: two rectangles rotated 90° relative to each other. I wrote the following short script:

FROM

0.0 0.0 0.0
1.0 0.0 0.0
0.0 2.0 0.0
1.0 2.0 0.0

TO

0.0 0.0 0.0
2.0 0.0 0.0
2.0 1.0 0.0
0.0 1.0 0.0

import numpy as np
import files
from scipy.spatial.transform import Rotation

from_ = np.genfromtxt(files.FROM)
to = np.genfromtxt(files.TO)

mean_from = np.mean(from_, axis=0)
mean_to = np.mean(to, axis=0)

centered_from = from_ - mean_from
centered_to = to - mean_to

rot,res = Rotation.align_vectors(centered_from,centered_to)
print(res)
sol = rot.as_quat()
print("({}f, {}f, {}f, {}f)".format(sol[0],sol[1],sol[2],sol[3]))

Output:

Optimal rotation is not uniquely or poorly defined for the given sets of vectors.
  rot,res = Rotation.align_vectors(centered_from,centered_to)
2.449489742783178
(0.0f, 0.0f, 0.0f, 1.0f)

Why would it fail on such a simple example? Even with a more complicated example of 12 points, the result does not align the points properly.

Complex example (the points are rotated 90° in one direction and then 180° in another direction):

FROM

-1.061441 1.029105 0.3137982
-1.058325 0.6978046 0.3261206
-0.5160019 1.044142 0.5801408
-0.5128856 0.7128415 0.5924632
-0.5150666 1.089168 0.5781018
-0.5124189 0.733911 0.5916736
0.09491175 1.098186 0.6951484
0.09755945 0.7429286 0.7087202
-0.4337781 0.6063982 0.2094956
-0.3675401 0.5803195 0.08044138
-0.1081934 0.6057502 0.3767354
-0.04195532 0.5796715 0.2476811

TO

0.6978049 -1.058325 -0.3261207
0.7128419 -0.5128859 -0.5924634
1.029105 -1.061442 -0.3137983
1.044142 -0.5160022 -0.580141
0.7339112 -0.512419 -0.5916737
0.7429286 0.09755942 -0.7087204
1.089168 -0.5150667 -0.5781019
1.098186 0.09491169 -0.6951486
0.5795034 -0.2047889 -0.1641395
0.5804876 -0.2047065 -0.1639829
0.6055822 -0.271027 -0.2931938
0.6065664 -0.2709446 -0.2930371

That's everything I tried, and I don't understand how to pin down the problem any further.


Solution

  • The function Rotation.align_vectors() assumes that the vectors in both arrays are in the same order, so that row i of one array corresponds to row i of the other. Your points are not in the same order. You can see this by plotting them.

    Simple example from/to:

    [Figures: plots of the "from" and "to" points]

    In the "from" example, the path starts at (0, 0), moves to (1, 0), moves diagonally to (0, 2), then moves to (1, 2). In the "to" example, there is no diagonal move. No rotation can fix this on its own: it requires both a rotation and a permutation.

    You'll find something similar if you plot your complex example.
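
To illustrate the ordering issue on the simple example, here is a minimal sketch that reorders the "from" corners by hand before calling align_vectors. The index list `order` is my own hand-derived assumption (it is not in the question), chosen so that each reordered "from" corner is the one a 90° turn about the z axis maps onto the matching "to" corner. Note that align_vectors(a, b) estimates the rotation that maps b onto a.

```python
import numpy as np
from scipy.spatial.transform import Rotation

# Rectangle corners from the simple example, in the order given in the question.
from_ = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 2.0, 0.0]])
to = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [2.0, 1.0, 0.0], [0.0, 1.0, 0.0]])

centered_from = from_ - from_.mean(axis=0)
centered_to = to - to.mean(axis=0)

# Hand-derived order (an assumption of this sketch): row i of the reordered
# "from" array is the corner that a 90-degree turn about z maps onto row i of "to".
order = [1, 3, 2, 0]

# align_vectors(a, b) estimates the rotation that maps b onto a,
# so the target goes first and the points to be rotated go second.
rot, res = Rotation.align_vectors(centered_to, centered_from[order])

print(res)                           # residual, close to zero
print(np.degrees(rot.magnitude()))   # rotation angle, close to 90
```

With the rows paired correctly, the residual drops from about 2.45 to roughly zero and the warning about a poorly defined rotation should no longer appear.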

    To fix this, I expanded the search: fit a rotation for every permutation of the points and keep the one with the smallest residual.

    Here is the code I used:

    import itertools
    import math

    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.spatial.transform import Rotation
    from tqdm import tqdm


    def array_permutations(array):
        """Yield every row permutation of `array` together with its index tuple."""
        N = array.shape[0]
        for permute in itertools.permutations(range(N)):
            yield array[list(permute)], permute


    def find_best_rotation(source, target):
        """Try every row order of `source` and keep the rotation with the smallest residual."""
        best_res = None
        best_rot = None
        best_perm = None
        for source_permuted, perm in tqdm(array_permutations(source), total=math.factorial(source.shape[0])):
            # align_vectors(a, b) returns the rotation that maps b onto a, so the
            # target goes first to get a rotation that can be applied to the source.
            rot, res = Rotation.align_vectors(target, source_permuted)
            if best_res is None or res < best_res:
                best_res = res
                best_rot = rot
                best_perm = perm
        return best_rot, best_res, best_perm


    # from_ and to are the arrays loaded in the question's script.
    limit = 8  # Only consider first N points of each dataset
    rot, res, perm = find_best_rotation(from_[:limit], to[:limit])
    sol = rot.as_quat()
    print("({}f, {}f, {}f, {}f)".format(sol[0], sol[1], sol[2], sol[3]))

    # Rotate all "from" points once and overlay them on the "to" points.
    rotated = rot.apply(from_)
    plt.figure(figsize=(6, 6))
    jitter = 0.005  # small offset so overlapping points stay visible
    plt.scatter(rotated[:, 0] + jitter, rotated[:, 1] + jitter)
    plt.scatter(to[:, 0], to[:, 1])
    plt.show()
    

    Here's the result this produces:

    [Figure: scatter plot of the rotated "from" points (blue) over the "to" points (orange)]

    Some notes on this solution:

    • This only searches permutations of the first 8 points to keep the search tractable. The number of permutations grows factorially (8! = 40,320 versus 12! = 479,001,600), so this won't scale to large numbers of points. Searching all permutations of 12 points was estimated to take 14 hours.
    • Some of the points seem to have no equivalent. On the left, there are 4 blue points, but only 2 orange ones. Since the other points in the dataset seem well-matched, I think this is a problem with the input dataset.
    • There may be a way to solve this without searching all permutations. I also considered using scipy.optimize.minimize to minimize the distance to the nearest point, but did not go this way because of local minima (e.g. in some cases, a point may need to move away from its nearest neighbor because that neighbor is not its true match).
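
As a rough illustration of the last note, here is one possible way to avoid the full permutation search. This is not the approach used above, just a sketch under stated assumptions: it alternates between a one-to-one matching step (scipy.optimize.linear_sum_assignment on squared distances) and a fitting step (align_vectors), and restarts from several random rotations because a single run can get stuck in a local minimum. The helper name `match_and_align` and its parameters are hypothetical, and the restarts reduce but do not eliminate the local-minimum risk.

```python
import warnings

import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from scipy.spatial.transform import Rotation


def match_and_align(source, target, n_restarts=100, n_iters=50, seed=0):
    """Hypothetical helper: alternate between matching points and fitting a rotation.

    Returns (rot, res, perm) where rot.apply(source[i]) approximates target[perm[i]].
    """
    rng = np.random.default_rng(seed)
    best_rot, best_res, best_perm = None, np.inf, None
    with warnings.catch_warnings():
        # Poor intermediate matchings can trigger the "not uniquely defined" warning.
        warnings.simplefilter("ignore", UserWarning)
        for _ in range(n_restarts):
            # Random start: a normalized Gaussian 4-vector is a uniformly random quaternion.
            rot = Rotation.from_quat(rng.normal(size=4))
            perm, res = None, np.inf
            for _ in range(n_iters):
                # Matching step: one-to-one assignment with minimal total squared distance.
                cost = cdist(rot.apply(source), target, metric="sqeuclidean")
                _, new_perm = linear_sum_assignment(cost)
                if perm is not None and np.array_equal(new_perm, perm):
                    break  # matching is stable, so the fitted rotation would not change
                perm = new_perm
                # Fitting step: align_vectors(a, b) returns the rotation mapping b onto a.
                rot, res = Rotation.align_vectors(target[perm], source)
            if res < best_res:
                best_rot, best_res, best_perm = rot, res, perm
    return best_rot, best_res, best_perm


# Centered rectangle corners from the simple example.
source = np.array([[-0.5, -1.0, 0.0], [0.5, -1.0, 0.0], [-0.5, 1.0, 0.0], [0.5, 1.0, 0.0]])
target = np.array([[-1.0, -0.5, 0.0], [1.0, -0.5, 0.0], [1.0, 0.5, 0.0], [-1.0, 0.5, 0.0]])

rot, res, perm = match_and_align(source, target)
print(res, perm)
```

Each matching step costs polynomial time in the number of points rather than factorial time, so this should stay cheap for 12 points. It still assumes every point has a true counterpart, which may not hold for the unmatched points noted above.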