Tags: numpy, python-3.6, parentheses

Execution time difference in matrix multiplication caused by parentheses


Given two 1D NumPy arrays a and b created with

import numpy as np

N = 100000
a = np.random.randn(N)
b = np.random.randn(N)

Why is there a considerable difference in execution time between the following two expressions?

# expression 1
c = a @ a * b @ b

# expression 2
c = (a @ a) * (b @ b)

Using the %timeit magic in Jupyter Notebook, I get the following results:

%timeit a @ a * b @ b

223 µs ± 6.97 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

and

%timeit (a @ a) * (b @ b)

17.4 µs ± 27.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
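To reproduce the comparison outside Jupyter, the standard-library timeit module can be used instead of %timeit. This is a minimal sketch assuming the question's setup; the loop count is an arbitrary choice, and absolute timings depend on the machine and on the BLAS library NumPy is linked against, so they will not match the numbers above exactly.

```python
import timeit

import numpy as np

N = 100000
a = np.random.randn(N)
b = np.random.randn(N)

loops = 200  # arbitrary; large enough to smooth out noise

# Expression 1: no explicit parentheses
t1 = timeit.timeit(lambda: a @ a * b @ b, number=loops)
# Expression 2: two dot products, then one scalar multiplication
t2 = timeit.timeit(lambda: (a @ a) * (b @ b), number=loops)

print(f"expression 1: {t1 / loops * 1e6:.1f} us per loop")
print(f"expression 2: {t2 / loops * 1e6:.1f} us per loop")
```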


Solution

  • Both versions compute two dot products of length-N vectors. However, the first expression additionally performs N multiplications, while the second needs only one.

    a @ a * b @ b is equivalent to ((a @ a) * b) @ b, because @ and * have the same precedence and group from left to right. Step by step:

    aa = a @ a  # N multiplications and additions -> scalar
    aab = aa * b  # N multiplications -> vector
    aabb = aab @ b  # N multiplications and additions -> scalar
    
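The grouping can be checked without running any NumPy code by parsing both expressions with the standard-library ast module and printing them fully parenthesized. The helper parenthesize below is a hypothetical illustration written for this purpose, not part of any library, and it only handles the names and the @ and * operators that appear here.

```python
import ast

# AST operator classes used in these expressions, mapped to their symbols
OPS = {ast.MatMult: "@", ast.Mult: "*"}


def parenthesize(node):
    """Return the expression with every binary operation explicitly parenthesized."""
    if isinstance(node, ast.Expression):
        return parenthesize(node.body)
    if isinstance(node, ast.BinOp):
        left = parenthesize(node.left)
        right = parenthesize(node.right)
        return f"({left} {OPS[type(node.op)]} {right})"
    if isinstance(node, ast.Name):
        return node.id
    raise TypeError(f"unsupported node: {type(node).__name__}")


print(parenthesize(ast.parse("a @ a * b @ b", mode="eval")))
# (((a @ a) * b) @ b)
print(parenthesize(ast.parse("(a @ a) * (b @ b)", mode="eval")))
# ((a @ a) * (b @ b))
```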

    (a @ a) * (b @ b) is equivalent to

    aa = a @ a  # N multiplications and additions -> scalar
    bb = b @ b  # N multiplications and additions -> scalar
    aabb = aa * bb  # 1 multiplication -> scalar
    
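Because a @ a is a scalar, the regrouping is mathematically valid: ((a @ a) * b) @ b = (a @ a) * (b @ b). A minimal check, assuming random input like the question's (the fixed seed is only there for reproducibility), spells out both orders and confirms they agree up to floating-point rounding, while only the first builds the length-N intermediate aab.

```python
import numpy as np

np.random.seed(0)  # fixed seed so the run is reproducible
N = 100000
a = np.random.randn(N)
b = np.random.randn(N)

# Expression 1, in the order Python evaluates a @ a * b @ b
aa = a @ a         # scalar
aab = aa * b       # temporary vector of length N (N multiplications)
expr1_steps = aab @ b

# Expression 2: two dot products, then a single scalar multiplication
expr2_steps = (a @ a) * (b @ b)

print(np.isclose(expr1_steps, a @ a * b @ b))  # step-by-step matches expression 1
print(np.isclose(expr1_steps, expr2_steps))    # both groupings agree up to rounding
```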

    It is well known that the cost of a chain of matrix multiplications can depend on where the parentheses are placed. There are algorithms that optimize matrix chain multiplication by exploiting this fact.
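The textbook approach to matrix chain ordering is a dynamic program over all split points. The answer does not spell out a specific algorithm, so the following is only a minimal pure-Python sketch of that standard technique; the function name matrix_chain_order, the dims convention and the M0, M1, ... labels are illustrative choices.

```python
def matrix_chain_order(dims):
    """Classic O(n^3) dynamic program for matrix chain multiplication.

    dims has length n + 1: matrix Mi has shape (dims[i], dims[i + 1]).
    Returns (minimum number of scalar multiplications, optimal parenthesization).
    """
    n = len(dims) - 1
    # cost[i][j]: cheapest way to compute Mi @ ... @ Mj
    cost = [[0] * n for _ in range(n)]
    # split[i][j]: index k of the last multiplication (Mi..Mk) @ (Mk+1..Mj)
    split = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):  # length of the sub-chain
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i][j] = float("inf")
            for k in range(i, j):  # try every split point
                c = cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                if c < cost[i][j]:
                    cost[i][j] = c
                    split[i][j] = k

    def build(i, j):
        # Rebuild the optimal parenthesization from the recorded split points
        if i == j:
            return f"M{i}"
        k = split[i][j]
        return f"({build(i, k)} @ {build(k + 1, j)})"

    return cost[0][n - 1], build(0, n - 1)


print(matrix_chain_order([10, 100, 5, 50]))
# (7500, '((M0 @ M1) @ M2)')
```

For shapes 10x100, 100x5 and 5x50, the left grouping needs 7500 scalar multiplications, a tenth of the 75000 that M0 @ (M1 @ M2) would cost.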

    Update: As I just learned, NumPy has a function for optimizing chains of matrix multiplications: numpy.linalg.multi_dot
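A short usage sketch of numpy.linalg.multi_dot, assuming arbitrary random matrices whose shapes are chosen only so that the evaluation order matters. multi_dot takes a sequence of arrays and chooses the multiplication order itself; its first and last arguments may also be 1-D vectors. Note that it handles chains of matrix products only, so it does not rewrite the scalar-times-vector step in the question's expression; explicit parentheses remain the direct fix there.

```python
import numpy as np

np.random.seed(0)  # fixed seed so the run is reproducible
A = np.random.randn(10, 100)
B = np.random.randn(100, 5)
C = np.random.randn(5, 50)

# multi_dot chooses the evaluation order for the whole chain itself
result = np.linalg.multi_dot([A, B, C])

print(result.shape)                      # (10, 50)
print(np.allclose(result, A @ B @ C))    # True
```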