Search code examples
Tags: python, numpy, matrix, vectorization, numba

Vectorization of complicated matrix calculation in Python


I have a 3x3 matrix that I can calculate as follows:

    # Nested comprehension -> array of shape (3, 3) whose [i, j] entry is m[i]*m[j] - n[i]*n[j]
    e_matrix = np.array([[m[i]*m[j]-n[i]*n[j] for j in range(3)] for i in range(3)])

where m and n are both length-3 vectors. This works ok ✅

Now suppose that `m` and `n` are matrices of shape `(K, 3)`. I want to do the analogous calculation for each row `k` and collect the results in an `e_matrix` of shape `(3, 3, K)`, so that `e_matrix[:, :, k]` is the 3x3 matrix for row `k`.

I realise I could just use a naive loop, e.g.:


# Naive approach: one 3x3 matrix per row k, stored in the slice e_matrix[:, :, k].
e_matrix = np.zeros((3,3,K))

for k in range(K):
    # The comprehension must be nested ([i][j]) so that it builds a (3, 3) array;
    # a flat 9-element array cannot be assigned into the (3, 3) slice.
    e_matrix[:,:,k] = np.array([[m[k,i]*m[k,j]-n[k,i]*n[k,j] for j in range(3)] for i in range(3)])
   

Does a vectorized approach exist? Ideally it should also work with Numba JIT compilation (so nothing like `np.tensordot`).


Solution

  • This is the "vectorized" variant, using Numba's `guvectorize`.

    import numpy as np
    import numba
    
    
    @numba.guvectorize("(k),(k)->(k,k)", nopython=True)
    def nb_vec_impl(m, n, out):
        """Fill ``out`` with ``outer(m, m) - outer(n, n)``.

        Core kernel of a generalized ufunc. In the layout string, ``k`` is the
        length of the vectors ``m`` and ``n`` (not the batch size ``K`` used by
        the callers), and ``out`` is the (k, k) core output. Any leading axes of
        the arguments are looped over by the gufunc machinery.
        """
        size = m.shape[0]
        # Take the loop bound from the core dimension rather than hard-coding 3.
        # Numba does not bounds-check indexing by default, so a fixed 3 would
        # write past the end of ``out`` for k < 3 and leave it partly unfilled for k > 3.
        for i in range(size):
            for j in range(size):
                out[i, j] = m[i] * m[j] - n[i] * n[j]
    
    
    def nb_vec(m, n):
        """Return the (3, 3, K) stack of ``outer(m[k], m[k]) - outer(n[k], n[k])``.

        ``m`` and ``n`` are (K, 3) arrays. The work is done by the
        ``nb_vec_impl`` gufunc, which writes directly into a view of the
        preallocated result.
        """
        K = len(m)
        e_matrix = np.zeros((3, 3, K))  # Notice the position of K.
        # The gufunc loops over the leading axis and writes its (3, 3) core output
        # into the trailing two axes, so it needs a (K, 3, 3) buffer. A transposed
        # view of the (3, 3, K) result provides that without copying.
        # ``transpose(2, 0, 1)`` maps view[k, i, j] -> e_matrix[i, j, k] exactly;
        # a plain ``.T`` reverses *all* axes and stores entry (i, j) at [j, i, k],
        # which is only harmless because this particular matrix is symmetric.
        nb_vec_impl(m, n, e_matrix.transpose(2, 0, 1))
        return e_matrix
    

    This function is fast when K is very large (e.g., 100,000), but unfortunately very slow when K is small, because of its large per-call overhead.

    If performance is your priority, I would prefer a simple `@numba.njit` function:

    import numpy as np
    import numba
    
    
    @numba.njit
    def nb_jit(m, n):
        """Return ``e`` with ``e[i, j, k] = m[k, i]*m[k, j] - n[k, i]*n[k, j]``.

        ``m`` and ``n`` must be 2-D arrays of the same shape (K, d); d is 3 in
        the question, but any row length works. The result is a float64 array
        of shape (d, d, K).

        Raises ValueError if ``m`` and ``n`` differ in shape.
        """
        if m.shape != n.shape:
            # Numba does not bounds-check indexing by default, so a mismatched
            # ``n`` would otherwise be read past its end.
            raise ValueError("m and n must have the same shape")
        K, d = m.shape
        # Every element is written by the loops below, so zero-filling is unnecessary.
        e_matrix = np.empty((d, d, K))

        for k in range(K):
            for i in range(d):
                for j in range(d):
                    e_matrix[i, j, k] = m[k, i] * m[k, j] - n[k, i] * n[k, j]

        return e_matrix
    

    Benchmark:

    import timeit
    
    import numba
    import numpy as np
    
    
    def native(m, n):
        """Reference implementation using a plain Python loop over rows.

        ``m`` and ``n`` are (K, d) arrays. Returns a float64 array of shape
        (d, d, K) whose slice ``[:, :, k]`` is
        ``outer(m[k], m[k]) - outer(n[k], n[k])``.
        """
        K, d = m.shape
        e_matrix = np.zeros((d, d, K))

        for k in range(K):
            # Outer comprehension runs over rows (i), inner over columns (j), so
            # the nested list's [i][j] entry lands at e_matrix[i, j, k].
            e_matrix[:, :, k] = np.array(
                [[m[k, i] * m[k, j] - n[k, i] * n[k, j] for j in range(d)] for i in range(d)]
            )

        return e_matrix
    
    
    def broadcast(m, n):
        """Vectorised NumPy version: (K, d) inputs -> (d, d, K) output via broadcasting."""
        mt = m.T  # (d, K)
        nt = n.T
        # (d, 1, K) * (1, d, K) broadcasts to (d, d, K): entry [i, j, k] = mt[i, k] * mt[j, k].
        m_outer = mt[:, np.newaxis, :] * mt[np.newaxis, :, :]
        n_outer = nt[:, np.newaxis, :] * nt[np.newaxis, :, :]
        return m_outer - n_outer
    
    
    @numba.njit
    def broadcast_nb(m, n):
        """Numba-compiled twin of ``broadcast``: (K, d) inputs -> (d, d, K) output."""
        mt = m.T
        nt = n.T
        # Indexing with None inserts a length-1 axis: (d, 1, K) * (1, d, K) -> (d, d, K).
        m_outer = mt[:, None] * mt[None, :]
        n_outer = nt[:, None] * nt[None, :]
        return m_outer - n_outer
    
    
    @numba.njit
    def func2(m, n):
        """Reshape-based variant: (K, d) inputs -> (d, d, K) output."""
        mt, nt = m.T, n.T
        d, num_rows = mt.shape
        # Column-shaped views vary along the first output axis, row-shaped views along the second.
        m_col = mt.reshape(d, 1, num_rows)
        m_row = mt.reshape(1, d, num_rows)
        n_col = nt.reshape(d, 1, num_rows)
        n_row = nt.reshape(1, d, num_rows)
        return m_col * m_row - n_row * n_col
    
    
    @numba.njit
    def nb_jit(m, n):
        """Return ``e`` with ``e[i, j, k] = m[k, i]*m[k, j] - n[k, i]*n[k, j]``.

        ``m`` and ``n`` must be 2-D arrays of the same shape (K, d); d is 3 in
        the question, but any row length works. The result is a float64 array
        of shape (d, d, K).

        Raises ValueError if ``m`` and ``n`` differ in shape.
        """
        if m.shape != n.shape:
            # Numba does not bounds-check indexing by default, so a mismatched
            # ``n`` would otherwise be read past its end.
            raise ValueError("m and n must have the same shape")
        K, d = m.shape
        # Every element is written by the loops below, so zero-filling is unnecessary.
        e_matrix = np.empty((d, d, K))

        for k in range(K):
            for i in range(d):
                for j in range(d):
                    e_matrix[i, j, k] = m[k, i] * m[k, j] - n[k, i] * n[k, j]

        return e_matrix
    
    
    @numba.guvectorize("(k),(k)->(k,k)", nopython=True)
    def nb_vec_impl(m, n, out):
        """Fill ``out`` with ``outer(m, m) - outer(n, n)``.

        Core kernel of a generalized ufunc. In the layout string, ``k`` is the
        length of the vectors ``m`` and ``n`` (not the batch size ``K`` used by
        the callers), and ``out`` is the (k, k) core output. Any leading axes of
        the arguments are looped over by the gufunc machinery.
        """
        size = m.shape[0]
        # Take the loop bound from the core dimension rather than hard-coding 3.
        # Numba does not bounds-check indexing by default, so a fixed 3 would
        # write past the end of ``out`` for k < 3 and leave it partly unfilled for k > 3.
        for i in range(size):
            for j in range(size):
                out[i, j] = m[i] * m[j] - n[i] * n[j]
    
    
    def nb_vec(m, n):
        """Return the (3, 3, K) stack of ``outer(m[k], m[k]) - outer(n[k], n[k])``.

        ``m`` and ``n`` are (K, 3) arrays. The work is done by the
        ``nb_vec_impl`` gufunc, which writes directly into a view of the
        preallocated result.
        """
        K = len(m)
        e_matrix = np.zeros((3, 3, K))
        # The gufunc writes its (3, 3) core output into the trailing axes of a
        # (K, 3, 3) buffer. ``transpose(2, 0, 1)`` is such a view of the result
        # with view[k, i, j] -> e_matrix[i, j, k]; a plain ``.T`` would also swap
        # i and j, which only works because this matrix is symmetric.
        nb_vec_impl(m, n, e_matrix.transpose(2, 0, 1))
        return e_matrix
    
    
    def main(*, K=100_000, n_run=10):
        """Check every candidate against ``native`` and print its mean runtime.

        ``K`` is the number of rows in the random (K, 3) inputs (e.g. 100_000 or
        100); ``n_run`` is the number of timed calls averaged per candidate.

        Raises AssertionError if a candidate's result differs from ``native``.
        """
        rng = np.random.default_rng(0)
        m = rng.random((K, 3))
        n = rng.random((K, 3))

        candidates = [native, broadcast, broadcast_nb, func2, nb_jit, nb_vec]

        # The pure-Python loop serves as the reference result.
        expected = native(m, n)

        for f in candidates:
            # The first call also triggers Numba's lazy JIT compilation, so
            # compile time is kept out of the timing below.
            result = f(m, n)
            # Explicit check instead of ``assert`` so it still runs under ``python -O``.
            if not np.array_equal(result, expected):
                raise AssertionError(f"{f.__name__}")
            elapsed = timeit.timeit(lambda: f(m, n), number=n_run) / n_run
            print(f"{f.__name__}: {elapsed * 1000:.3f} ms")


    if __name__ == "__main__":
        main()
    

    Results (for K = 100_000):

    native: 778.374 ms
    broadcast: 13.581 ms
    broadcast_nb: 2.848 ms
    func2: 6.207 ms
    nb_jit: 1.998 ms
    nb_vec: 1.891 ms
    

    Results (for K = 100):

    native: 0.798 ms
    broadcast: 0.016 ms
    broadcast_nb: 0.004 ms
    func2: 0.007 ms
    nb_jit: 0.002 ms
    nb_vec: 0.127 ms