Tags: python, arrays, numpy, numba, array-broadcasting

Numba typing error when multiplying a single vector with an array of vectors using broadcasting


I'm having a problem applying numba to a set of functions I'm trying to optimise for performance. All the functions work fine without numba but I get a compilation error when I try to use numba.

Here's the compilation error I'm struggling with:

Exception occurred:
Type: TypingError
Message: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify array(float64, 2d, C) and array(float64, 1d, C) for 'q1.2', defined at .\rotations.py (82)

File "rotations.py", line 82:
def quaternion_mult(q1, qa):
    <source elided>

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    ^

During: typing of assignment at .\rotations.py (82)

File "rotations.py", line 82:
def quaternion_mult(q1, qa):
    <source elided>

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    ^

During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)

During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)


File "rotations.py", line 102:
def quaternion_vect_mult(q1, vect_array):
    <source elided>

    temp = quaternion_mult(q1, q2)
    ^

And here's the full code of the corresponding functions:


@njit(cache=True)
def quaternion_conjugate_vect(q):
    """
    return the conjugate of a quaternion or an array of quaternions
    """
    return q * np.array([1, -1, -1, -1])


@njit(cache=True)
def quaternion_mult(q1, qa):
    """
    multiply an array of quaternions (Nx4) by a single quaternion.

    qa is always a (Nx4) array of quaternions np.ndarray
    q1 is always a single (1x4) quaternion np.ndarray

    """
    N = max(len(qa), len(q1))
    quat_result = np.zeros((N, 4), dtype=np.float64)

    if qa.ndim == 1:
        q2 = qa.copy().reshape((1, -1))
        # q2 = np.reshape(q1, (1,-1))
    else:
        q2 = qa

    if q1.ndim == 1:
        # q1 = q1.copy().reshape((1, -1))
        q1 = np.reshape(q1, (1, -1))

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    quat_result[:, 1] = (q1[:, 0] * q2[:, 1]) + (q1[:, 1] * q2[:, 0]) + (q1[:, 2] * q2[:, 3]) - (q1[:, 3] * q2[:, 2])
    quat_result[:, 2] = (q1[:, 0] * q2[:, 2]) + (q1[:, 2] * q2[:, 0]) + (q1[:, 3] * q2[:, 1]) - (q1[:, 1] * q2[:, 3])
    quat_result[:, 3] = (q1[:, 0] * q2[:, 3]) + (q1[:, 3] * q2[:, 0]) + (q1[:, 1] * q2[:, 2]) - (q1[:, 2] * q2[:, 1])

    return quat_result


@njit(cache=True)
def quaternion_vect_mult(q1, vect_array):
    """
    Multiplies an array of x,y,z coordinates by a single quaternion q1.
    """
    # q1 is the quaternion which the coordinates will be rotated by.

    # Add initial column of zeros to array
    # N = len(vect_array)
    q2 = np.zeros((len(vect_array), 4), dtype=np.float64)
    q2[:, 1::] = vect_array

    temp = quaternion_mult(q1, q2)
    result = quaternion_mult(temp, quaternion_conjugate_vect(q1))

    return result[:, 1::]
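If columns 0-3 are read as (w, x, y, z), the four assignments in quaternion_mult appear to implement the Hamilton product q_1 q_2. quaternion_vect_mult then computes q v q̄ for each coordinate row embedded as a pure quaternion v = (0, x, y, z), where q̄ is what quaternion_conjugate_vect returns:

```latex
\begin{aligned}
(q_1 q_2)_w &= w_1 w_2 - x_1 x_2 - y_1 y_2 - z_1 z_2 \\
(q_1 q_2)_x &= w_1 x_2 + x_1 w_2 + y_1 z_2 - z_1 y_2 \\
(q_1 q_2)_y &= w_1 y_2 + y_1 w_2 + z_1 x_2 - x_1 z_2 \\
(q_1 q_2)_z &= w_1 z_2 + z_1 w_2 + x_1 y_2 - y_1 x_2 \\[4pt]
\bar{q} &= (w, -x, -y, -z), \qquad v' = q\, v\, \bar{q}, \qquad v = (0, x, y, z)
\end{aligned}
```

For a unit quaternion q this is a pure rotation. For a non-normalised q such as np.random.random(4), the rotated vectors are also scaled by |q|^2.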

I don't understand the unification error. I'm relying on broadcasting in the multiplication, so shouldn't the shape be irrelevant? All arrays are np.float64, which is why I've specified that dtype. The only difference is the shape, but normal NumPy broadcasting should work here, as it does without numba. (I've added a lot of extra parentheses to make sure I was multiplying things correctly, but they aren't needed at all.)

I assume the problem has something to do with creating the np.zeros storage array. I added it because previously I computed each column separately and then combined them with np.stack.

My only other thought is that it relates to the if ... else ... where I check whether the single quaternion has shape (1, 4) instead of (4,).

I'm a bit stumped by this. Similar problems I've found usually seem to involve a type difference, like int vs. float or float32 vs. float64.
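Here is a NumPy-only sketch of the broadcasting point above. The values are arbitrary placeholders. A (4,) quaternion and its (1, 4) reshape give identical column products against an (N, 4) array, so the arithmetic itself does not depend on which shape is used:

```python
import numpy as np

q_single = np.array([1.0, 2.0, 3.0, 4.0])                 # one quaternion, shape (4,)
q_array = np.arange(12, dtype=np.float64).reshape(3, 4)  # three quaternions, shape (3, 4)

# 1D quaternion: q_single[0] is a scalar that broadcasts over the whole column
col_from_1d = q_single[0] * q_array[:, 0]

# (1, 4) quaternion: q_row[:, 0] has shape (1,), which broadcasts to (3,)
q_row = q_single.reshape((1, -1))
col_from_2d = q_row[:, 0] * q_array[:, 0]

print(q_single.ndim, q_row.ndim)                 # 1 2
print(col_from_1d.shape, col_from_2d.shape)      # (3,) (3,)
print(np.array_equal(col_from_1d, col_from_2d))  # True
```

What does change is ndim, and with it the numba array type, when a name like q1 is rebound to its reshaped version.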

Any help is appreciated.

For clarity, here's an example that works without numba but fails with it enabled:

from numba import njit
import numpy as np

quat_single = np.random.random(4)  # shape (4,)
coord_array = np.random.random([9,3])

Note: quat_single = np.random.random([1,4]) will work with `numba`


quaternion_vect_mult(quat_single, coord_array)
Out[18]: 
array([[ 0.12035005,  1.51894951,  0.26731225],
       [ 1.56889141,  0.56465019,  0.18818138],
       [ 0.58966646,  1.09653585, -0.19548354],
       [ 1.15044012,  1.56034916,  0.73943456],
       [ 0.83003034,  1.80861828,  0.02678796],
       [ 1.15572912,  0.54263501,  0.16206597],
       [ 1.34243762,  1.0802315 , -0.20735991],
       [ 1.5876305 ,  0.70017144,  0.80066164],
       [ 1.20734218,  1.2747372 , -0.47177605]])


Solution

  • With these lines:

        temp = quaternion_mult(q1, q2)
        result = quaternion_mult(temp, quaternion_conjugate_vect(q1))
    

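    you're calling quaternion_mult with differently shaped arguments each time: first (1D quaternion, 2D array), then (2D array, 1D quaternion). Numba compiles a separate version of the function for each argument signature. Inside those versions, however, the if/else blocks make q1 (first call) or q2 (second call) either a 1D or a 2D array depending on the branch taken. Numba's type inference needs each variable to have a single type within a compiled function, hence "Cannot unify array(float64, 2d, C) and array(float64, 1d, C)".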

    Create a separate quaternion_mult for each combination of argument dimensions you want to support, so that no variable ever has to change its dimensionality, e.g.:

    from numba import njit
    import numpy as np
    
    
    @njit(cache=True)
    def quaternion_conjugate_vect(q):
        """
        return the conjugate of a quaternion or an array of quaternions
        """
        return q * np.array([1, -1, -1, -1])
    
    
    @njit(cache=True)
    def quaternion_mult1(q1, qa):
        """
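        left-multiply each quaternion in qa (Nx4) by a single quaternion q1.
    
        qa is always a (Nx4) array of quaternions np.ndarray
        q1 is always a single quaternion, a 1D np.ndarray of shape (4,)
    
        """
        # one output row per quaternion in qa: max(len(qa), len(q1)) would give 4
        # when qa has fewer than 4 rows, and the column assignments would then fail
        N = len(qa)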
        quat_result = np.zeros((N, 4), dtype=np.float64)
    
        # if qa.ndim == 1:
        #     q2 = qa.copy().reshape((1, -1))
        #     # q2 = np.reshape(q1, (1,-1))
        # else:
        #     q2 = qa
    
        q2 = qa
    
        # if q1.ndim == 1:
        #     # q1 = q1.copy().reshape((1, -1))
        #     q1 = np.reshape(q1, (1, -1))
    
        quat_result[:, 0] = (
            (q1[0] * q2[:, 0])
            - (q1[1] * q2[:, 1])
            - (q1[2] * q2[:, 2])
            - (q1[3] * q2[:, 3])
        )
        quat_result[:, 1] = (
            (q1[0] * q2[:, 1])
            + (q1[1] * q2[:, 0])
            + (q1[2] * q2[:, 3])
            - (q1[3] * q2[:, 2])
        )
        quat_result[:, 2] = (
            (q1[0] * q2[:, 2])
            + (q1[2] * q2[:, 0])
            + (q1[3] * q2[:, 1])
            - (q1[1] * q2[:, 3])
        )
        quat_result[:, 3] = (
            (q1[0] * q2[:, 3])
            + (q1[3] * q2[:, 0])
            + (q1[1] * q2[:, 2])
            - (q1[2] * q2[:, 1])
        )
    
        return quat_result
    
    
    @njit(cache=True)
    def quaternion_mult2(q1, qa):
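        """
        right-multiply each quaternion in q1 (Nx4) by a single quaternion qa of shape (4,).
        """
        # one output row per quaternion in q1: max(len(qa), len(q1)) would give 4
        # when q1 has fewer than 4 rows, and the column assignments would then fail
        N = len(q1)
        quat_result = np.zeros((N, 4), dtype=np.float64)
        # reshape qa to (1x4) so that q2[:, k] has shape (1,) and broadcasts against q1[:, k]
        q2 = qa.copy().reshape((1, -1))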
    
        quat_result[:, 0] = (
            (q1[:, 0] * q2[:, 0])
            - (q1[:, 1] * q2[:, 1])
            - (q1[:, 2] * q2[:, 2])
            - (q1[:, 3] * q2[:, 3])
        )
        quat_result[:, 1] = (
            (q1[:, 0] * q2[:, 1])
            + (q1[:, 1] * q2[:, 0])
            + (q1[:, 2] * q2[:, 3])
            - (q1[:, 3] * q2[:, 2])
        )
        quat_result[:, 2] = (
            (q1[:, 0] * q2[:, 2])
            + (q1[:, 2] * q2[:, 0])
            + (q1[:, 3] * q2[:, 1])
            - (q1[:, 1] * q2[:, 3])
        )
        quat_result[:, 3] = (
            (q1[:, 0] * q2[:, 3])
            + (q1[:, 3] * q2[:, 0])
            + (q1[:, 1] * q2[:, 2])
            - (q1[:, 2] * q2[:, 1])
        )
    
        return quat_result
    
    
    @njit(cache=True)
    def quaternion_vect_mult(q1, vect_array):
        """
        Multiplies an array of x,y,z coordinates by a single quaternion q1.
        """
        # q1 is the quaternion which the coordinates will be rotated by.
    
        # Add initial column of zeros to array
        # N = len(vect_array)
        q2 = np.zeros((len(vect_array), 4), dtype=np.float64)
        q2[:, 1::] = vect_array
    
        temp = quaternion_mult1(q1, q2)
        result = quaternion_mult2(temp, quaternion_conjugate_vect(q1))
    
        return result[:, 1::]
    

    With:

    np.random.seed(42)
    quat_single = np.random.random(4)
    coord_array = np.random.random([9, 3])
    
    print(quaternion_vect_mult(quat_single, coord_array))
    

    This prints:

    [[ 0.26852132  0.20522199  0.28520316]
     [ 1.89120847  1.35797162  0.79965888]
     [ 2.322112   -0.39389235  0.76960471]
     [ 0.51270351  0.3143128   0.24153831]
     [ 1.2691966   0.32325645  0.60666047]
     [ 0.85615508  0.20021656  1.01254022]
     [ 1.15864463  0.39780013  0.3251974 ]
     [ 1.17341506  1.41237398  0.29654629]
     [ 1.15734464  1.16277993 -0.14839415]]
    

    According to my rough benchmark, it should be ~30-40x faster than the non-jitted version.
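A minimal sketch of a single-function alternative follows. It is not what the answer above uses, and the names quaternion_mult_2d and rotate_vectors are hypothetical. Instead of rebinding the arguments, it binds 2D views to new names with np.atleast_2d, which numba lists among its supported NumPy functions, so each variable keeps one array type in every compiled version. The try/except fallback to plain Python is an assumption added only so the sketch still runs where numba is not installed. cache=True is left out to keep it minimal.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: without numba, run the same code as plain NumPy
    def njit(func):
        return func


@njit
def quaternion_mult_2d(qa, qb):
    """Hamilton product of quaternions given as (4,) or (N, 4) arrays."""
    # New names for the 2D views; qa and qb are never rebound, so every
    # variable has a single array type in each compiled version.
    a = np.atleast_2d(qa)  # (4,) -> (1, 4); an (N, 4) array is returned as is
    b = np.atleast_2d(qb)
    # assumes one operand has 1 row (it broadcasts) or both have the same rows
    n = max(a.shape[0], b.shape[0])
    out = np.empty((n, 4), dtype=np.float64)
    out[:, 0] = a[:, 0] * b[:, 0] - a[:, 1] * b[:, 1] - a[:, 2] * b[:, 2] - a[:, 3] * b[:, 3]
    out[:, 1] = a[:, 0] * b[:, 1] + a[:, 1] * b[:, 0] + a[:, 2] * b[:, 3] - a[:, 3] * b[:, 2]
    out[:, 2] = a[:, 0] * b[:, 2] + a[:, 2] * b[:, 0] + a[:, 3] * b[:, 1] - a[:, 1] * b[:, 3]
    out[:, 3] = a[:, 0] * b[:, 3] + a[:, 3] * b[:, 0] + a[:, 1] * b[:, 2] - a[:, 2] * b[:, 1]
    return out


@njit
def rotate_vectors(q, vect_array):
    """Compute q * v * conj(q) for each (x, y, z) row of vect_array."""
    v = np.zeros((vect_array.shape[0], 4), dtype=np.float64)
    v[:, 1:] = vect_array  # embed each row as a pure quaternion (0, x, y, z)
    q_conj = q * np.array([1.0, -1.0, -1.0, -1.0])
    return quaternion_mult_2d(quaternion_mult_2d(q, v), q_conj)[:, 1:]


# usage: a unit quaternion for a 90 degree rotation about z maps x onto y
theta = np.pi / 2
q_z90 = np.array([np.cos(theta / 2), 0.0, 0.0, np.sin(theta / 2)])
rotated = rotate_vectors(q_z90, np.array([[1.0, 0.0, 0.0]]))
print(np.allclose(rotated, [[0.0, 1.0, 0.0]]))  # True
```

The same function handles both call patterns, (1D, 2D) and (2D, 1D), because each pattern gets its own compiled version and nothing inside changes type.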