
cross products with einsums


I'm trying to compute the cross-products of many 3x1 vector pairs as fast as possible. This

import numpy as np

n = 10000
a = np.random.rand(n, 3)
b = np.random.rand(n, 3)
np.cross(a, b)

gives the correct answer, but motivated by this answer to a similar question, I thought that einsum would get me somewhere. I found that both

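# Levi-Civita symbol: +1 for even permutations of (0, 1, 2), -1 for odd ones
eijk = np.zeros((3, 3, 3))
eijk[0, 1, 2] = eijk[1, 2, 0] = eijk[2, 0, 1] = 1
eijk[0, 2, 1] = eijk[2, 1, 0] = eijk[1, 0, 2] = -1

# result[r, i] = sum over j, k of eijk[i, j, k] * a[r, j] * b[r, k]
np.einsum('ijk,aj,ak->ai', eijk, a, b)
# same contraction in two steps, via a (3, n, 3) intermediate
np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)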

compute the cross product, but their performance is disappointing: both are much slower than np.cross:

%timeit np.cross(a, b)
1000 loops, best of 3: 628 µs per loop
%timeit np.einsum('ijk,aj,ak->ai', eijk, a, b)
100 loops, best of 3: 9.02 ms per loop
%timeit np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)
100 loops, best of 3: 10.6 ms per loop

Any ideas on how to improve the einsums?
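As a quick sanity check, the sketch below confirms that both einsum expressions from the question agree with np.cross on random data, and counts the nonzero entries of eijk. Since eijk is the Levi-Civita symbol, the first expression computes result[r, i] = sum over j, k of eijk[i, j, k] * a[r, j] * b[r, k] for each row r. The array size and random seed are arbitrary choices for illustration.

```python
import numpy as np

# Levi-Civita symbol, built exactly as in the question
eijk = np.zeros((3, 3, 3))
eijk[0, 1, 2] = eijk[1, 2, 0] = eijk[2, 0, 1] = 1
eijk[0, 2, 1] = eijk[2, 1, 0] = eijk[1, 0, 2] = -1

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
a = rng.random((1000, 3))
b = rng.random((1000, 3))

reference = np.cross(a, b)
one_step = np.einsum('ijk,aj,ak->ai', eijk, a, b)
two_step = np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)

print(np.allclose(one_step, reference))  # True
print(np.allclose(two_step, reference))  # True

# Only 6 of the 27 entries are nonzero; einsum treats eijk as a dense array
# and does not exploit that sparsity.
nonzero = np.count_nonzero(eijk)
print(nonzero, eijk.size)  # 6 27
```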


Solution

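  • einsum() performs many more multiplications than cross(): both einsum forms treat eijk as a dense 3x3x3 array, so every vector pair involves all 27 of its entries even though only 6 are nonzero, whereas cross() needs just 6 multiplications and 3 subtractions per pair. On top of that, newer NumPy versions implement cross() without creating many temporary arrays. So einsum() can't be faster than cross() here.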

    Here is the old implementation of cross(); each line allocates two temporary products plus a new array for their difference:

    x = a[1]*b[2] - a[2]*b[1]
    y = a[2]*b[0] - a[0]*b[2]
    z = a[0]*b[1] - a[1]*b[0]
    

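    Here is the new implementation, which writes each component directly into a view of the output array (cp0, cp1, cp2) and reuses a single temporary array tmp: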

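    # x component: cp0 = a1*b2 - a2*b1
    multiply(a1, b2, out=cp0)
    tmp = array(a2 * b1)
    cp0 -= tmp
    # y component: cp1 = a2*b0 - a0*b2 (tmp is reused from here on)
    multiply(a2, b0, out=cp1)
    multiply(a0, b2, out=tmp)
    cp1 -= tmp
    # z component: cp2 = a0*b1 - a1*b0
    multiply(a0, b1, out=cp2)
    multiply(a1, b0, out=tmp)
    cp2 -= tmp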
    

    To speed it up further, you need compiled code, e.g. Cython or Numba.
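The difference between the old and new cross() implementations quoted in the answer can be reproduced outside NumPy. This is a sketch, not NumPy's actual source: the function names cross_with_temporaries and cross_into_output are hypothetical, and both assume a and b are arrays of shape (n, 3). The first mirrors the old style, where every product allocates a fresh array; the second mirrors the newer style, writing each component into a column view of a preallocated result and reusing one scratch array.

```python
import numpy as np


def cross_with_temporaries(a, b):
    """Old style: each product and each difference allocates a new array."""
    x = a[:, 1] * b[:, 2] - a[:, 2] * b[:, 1]
    y = a[:, 2] * b[:, 0] - a[:, 0] * b[:, 2]
    z = a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
    return np.stack([x, y, z], axis=-1)  # one more copy to assemble the result


def cross_into_output(a, b):
    """Newer style: write into views of a preallocated output, reusing one scratch array."""
    a0, a1, a2 = a[:, 0], a[:, 1], a[:, 2]
    b0, b1, b2 = b[:, 0], b[:, 1], b[:, 2]
    cp = np.empty((a.shape[0], 3), dtype=np.result_type(a, b))
    cp0, cp1, cp2 = cp[:, 0], cp[:, 1], cp[:, 2]  # column views, no copies
    tmp = np.empty(a.shape[0], dtype=cp.dtype)    # the single scratch array

    np.multiply(a1, b2, out=cp0)
    np.multiply(a2, b1, out=tmp)
    cp0 -= tmp  # in-place on the view, so cp is updated
    np.multiply(a2, b0, out=cp1)
    np.multiply(a0, b2, out=tmp)
    cp1 -= tmp
    np.multiply(a0, b1, out=cp2)
    np.multiply(a1, b0, out=tmp)
    cp2 -= tmp
    return cp


a = np.random.rand(10000, 3)
b = np.random.rand(10000, 3)
print(np.allclose(cross_with_temporaries(a, b), np.cross(a, b)))  # True
print(np.allclose(cross_into_output(a, b), np.cross(a, b)))       # True
```

Timing both with %timeit on your own machine shows how much the avoided allocations matter; the gap will vary with n and hardware.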
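For the compiled route mentioned at the end of the answer, a minimal Numba sketch might look like the following. It assumes Numba is installed and that a and b share the same shape (n, 3) and dtype; the name cross_rows is hypothetical. The loop computes all three components of each row in one pass, with no temporary arrays. If Numba is missing, the fallback decorator simply runs the same loop as plain (much slower) Python so the sketch still executes.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback: run the loop as ordinary Python (much slower)
    def njit(func):
        return func


@njit
def cross_rows(a, b):
    """Row-wise cross product of two (n, 3) arrays in a single loop."""
    out = np.empty_like(a)
    for r in range(a.shape[0]):
        out[r, 0] = a[r, 1] * b[r, 2] - a[r, 2] * b[r, 1]
        out[r, 1] = a[r, 2] * b[r, 0] - a[r, 0] * b[r, 2]
        out[r, 2] = a[r, 0] * b[r, 1] - a[r, 1] * b[r, 0]
    return out


a = np.random.rand(10000, 3)
b = np.random.rand(10000, 3)
print(np.allclose(cross_rows(a, b), np.cross(a, b)))  # True
```

With Numba, the first call triggers JIT compilation, so time only subsequent calls when comparing against np.cross; whether it wins depends on n and your hardware.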