python multidimensional-array cupy numpy-einsum

Cupy Code Optimization: How to speed up nested for loops


I would like to optimize the Python code between the two perf_counter calls. By using CuPy I already obtained a substantial improvement compared to NumPy, and I am wondering whether there is some reordering or vectorization that I am missing. The main constraint is that I should not build any tensor (ndarray) bigger than dim^4, since this is part of a memory optimization in a bigger project, for which a method that scales with dim^5 is already known and performs better.

import numpy as np
import cupy as cp
from time import perf_counter
dim = 60
T_final = cp.random.rand(dim,dim,dim,dim)
T_start = cp.random.rand(dim,dim,dim,dim)
U = cp.random.rand(dim,dim,dim)
cp.cuda.Stream.null.synchronize()
start = perf_counter()
for xf in range(0,dim):
    for yu in range(0,dim):
        B = cp.einsum("ij,kim->kmj",U[:,:,yu],T_start[xf,:,:,:])
        for xb in range(0,dim):
            C = cp.einsum("kmj,kjs->ms",B,T_start[:,xb,:,:])
            T_final[xf,xb,yu,:]=cp.einsum("msc,ms",U,C)
cp.cuda.Stream.null.synchronize()
print(perf_counter()-start)
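To put the memory constraint in numbers, here is a minimal back-of-the-envelope sketch. It assumes float64 arrays (the default dtype of cp.random.rand) and the dim = 60 from the code above; `dense_bytes` is a hypothetical helper, not part of either library.

```python
# Rough memory footprint of dense float64 arrays (8 bytes per element).
BYTES_PER_ELEMENT = 8

def dense_bytes(dim: int, rank: int) -> int:
    """Bytes needed for a dense float64 array with `rank` axes of length `dim`."""
    return BYTES_PER_ELEMENT * dim**rank

dim = 60
size4 = dense_bytes(dim, 4)  # e.g. T_start or T_final
size5 = dense_bytes(dim, 5)  # one extra index, e.g. a five-index intermediate

print(f"dim^4: {size4 / 1e6:.1f} MB")  # dim^4: 103.7 MB
print(f"dim^5: {size5 / 1e9:.2f} GB")  # dim^5: 6.22 GB
```

Each additional index multiplies the footprint by dim, which is why a single five-index intermediate already needs 60 times the memory of T_start.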

Solution

  • I don't know CuPy well, but I don't think that matters much here. For the sake of simplicity I'll just show how to simplify these nested einsum calls; they are written with cp below, but they behave the same with np. The plan is to first wrap everything into one einsum call and only then optimize it.

    My recommendation is to work from the innermost loop outwards, removing one loop at a time and adding the corresponding index to einsum.

    So, to eliminate the loop over xb, we use the corresponding array dimension instead. Starting from the original

    for xf in range(0,dim):
        for yu in range(0, dim):
            B = cp.einsum("ij,kim->kmj", U[:, :, yu], T_start[xf, :, :, :])
            for xb in range(0, dim):
                C = cp.einsum("kmj,kjs->ms", B, T_start[:, xb, :, :])
                T_final2[xf, xb, yu, :] = cp.einsum("msc,ms->c", U, C) #adding the -> c to be explicit
    

    we get

    for xf in range(0, dim):
        for yu in range(0, dim):
            B = cp.einsum("ij,kim->kmj", U[:, :, yu], T_start[xf, :, :, :])
            C = cp.einsum("kmj,kbjs->bms", B, T_start)  # the xb loop becomes index b
            T_final2[xf, :, yu, :] = cp.einsum("msc,bms->bc", U, C)
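To make sure a loop-removal step did not change the result, it helps to compare it with the original loops on a small random input. A minimal sketch, assuming NumPy as a stand-in for CuPy (the einsum calls are identical) and hypothetical function names `loops_original` and `loops_without_xb`:

```python
import numpy as np

def loops_original(U, T_start):
    """The question's triple loop (NumPy stand-in for CuPy)."""
    dim = U.shape[0]
    out = np.empty((dim,) * 4)
    for xf in range(dim):
        for yu in range(dim):
            B = np.einsum("ij,kim->kmj", U[:, :, yu], T_start[xf])
            for xb in range(dim):
                C = np.einsum("kmj,kjs->ms", B, T_start[:, xb])
                out[xf, xb, yu, :] = np.einsum("msc,ms->c", U, C)
    return out

def loops_without_xb(U, T_start):
    """Same computation with the xb loop turned into the einsum index b."""
    dim = U.shape[0]
    out = np.empty((dim,) * 4)
    for xf in range(dim):
        for yu in range(dim):
            B = np.einsum("ij,kim->kmj", U[:, :, yu], T_start[xf])
            C = np.einsum("kmj,kbjs->bms", B, T_start)
            out[xf, :, yu, :] = np.einsum("msc,bms->bc", U, C)
    return out

rng = np.random.default_rng(0)
dim = 6
U = rng.random((dim, dim, dim))
T_start = rng.random((dim,) * 4)
print(np.allclose(loops_original(U, T_start), loops_without_xb(U, T_start)))  # True
```

Because every axis has the same length dim, a wrong output string such as "->bs" in the last call would still run without error but would not match the original, so a numerical comparison like this is the only way to catch it.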
    

    Now we can also remove the outer loops in the same way:

    B = cp.einsum("iju,fkim->ufkmj", U, T_start)
    C = cp.einsum("ufkmj,kbjs->ufbms", B, T_start)
    T_final2 = cp.einsum("msc,ufbms->fbuc", U, C)
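Note that B and C in this fully vectorized version carry five indices each (ufkmj and ufbms), i.e. dim^5 elements, which is what the question wants to avoid. A hypothetical middle ground, following the same innermost-first recipe, is to vectorize the xb and yu loops but keep the xf loop, so that no intermediate exceeds dim^4 elements. A minimal sketch with NumPy standing in for CuPy; the name `keep_xf_loop` is illustrative, and `optimize=True` is only there so NumPy may use tensordot/BLAS for the two-operand contractions where it can:

```python
import numpy as np

def keep_xf_loop(U, T_start):
    """Hypothetical variant: only the xf loop is kept, so every intermediate
    has at most dim^4 elements (NumPy stand-in for CuPy)."""
    dim = U.shape[0]
    T_final = np.empty((dim,) * 4)
    for xf in range(dim):
        # B[u, k, m, j] = sum_i U[i, j, u] * T_start[xf, k, i, m]  -> dim^4 elements
        B = np.einsum("iju,kim->ukmj", U, T_start[xf], optimize=True)
        # C[u, b, m, s] = sum_{k,j} B[u, k, m, j] * T_start[k, b, j, s]  -> dim^4 elements
        C = np.einsum("ukmj,kbjs->ubms", B, T_start, optimize=True)
        # T_final[xf, b, u, c] = sum_{m,s} U[m, s, c] * C[u, b, m, s]
        T_final[xf] = np.einsum("msc,ubms->buc", U, C, optimize=True)
    return T_final

# Check against the fully vectorized version (which builds dim^5 arrays).
rng = np.random.default_rng(1)
dim = 5
U = rng.random((dim, dim, dim))
T_start = rng.random((dim,) * 4)
B5 = np.einsum("iju,fkim->ufkmj", U, T_start)
C5 = np.einsum("ufkmj,kbjs->ufbms", B5, T_start)
reference = np.einsum("msc,ufbms->fbuc", U, C5)
print(np.allclose(keep_xf_loop(U, T_start), reference))  # True
```

The operation count is the same as in the fully vectorized version; the only cost of staying within the dim^4 budget is that dim loop iterations remain in Python.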
    
    

    But we are not done yet. Instead of making multiple calls to einsum, we can wrap everything in a single call, since einsum accepts more than two operands; that way more of the work happens inside NumPy's/CuPy's compiled code rather than in Python. We just plug in the corresponding operands and add their indices. You may experience a slowdown, though, because the evaluation order might not be ideal. So first we get

    C = cp.einsum("iju,fkim,kbjs->ufbms", U, T_start, T_start)
    T_final2 = cp.einsum("msc,ufbms->fbuc", U, C)
    

    and, eliminating the second call in the same way, we get

    T_final2 = cp.einsum("msc,iju,fkim,kbjs->fbuc", U, U, T_start, T_start)
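If you try this single call with plain NumPy, be aware of how np.einsum_path (which np.einsum(..., optimize=True) uses to pick an order) behaves here: by default it does not create intermediates larger than the largest input or output, which is dim^4 elements in this case. Every pairwise contraction of these four operands yields at least dim^5 elements, so NumPy finds no pairwise step and evaluates the whole expression in one pass over all nine indices. A small sketch (dim = 6 only, since that naive pass scales like dim^9) comparing the default path with one computed under an explicit size limit, passed as a tuple:

```python
import numpy as np

rng = np.random.default_rng(2)
dim = 6  # keep it tiny: the naive single pass scales like dim^9
U = rng.random((dim, dim, dim))
T_start = rng.random((dim,) * 4)
expr = "msc,iju,fkim,kbjs->fbuc"
ops = (U, U, T_start, T_start)

# Default cap = largest input/output (dim^4): no pairwise step fits under it.
path_default, _ = np.einsum_path(expr, *ops, optimize="optimal")
print(path_default)  # ['einsum_path', (0, 1, 2, 3)]

# Explicit cap of dim^5 elements: NumPy can now contract two operands at a time.
path_pairwise, report = np.einsum_path(expr, *ops, optimize=("optimal", dim**5))
print(path_pairwise)
print(report)  # human-readable summary incl. scaling and largest intermediate

# Same numbers either way; only run time and peak memory differ.
naive = np.einsum(expr, *ops, optimize=path_default)
pairwise = np.einsum(expr, *ops, optimize=path_pairwise)
print(np.allclose(naive, pairwise))  # True
```

In other words, NumPy's default cap happens to respect a dim^4 memory budget, but only by falling back to the naive evaluation.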
    

    NumPy offers np.einsum_path (and the optimize argument of np.einsum) for choosing a contraction order, but there is a more sophisticated library called opt_einsum that does exactly what you want. So you can replace the call above with

    from opt_einsum import contract 
    T_final2 = contract("msc,iju,fkim,kbjs->fbuc", U, U, T_start, T_start)
    

    Note that this also works with CuPy arrays, as opt_einsum supports various backends. If you want to see explicitly which contraction path is being used, you can call opt_einsum.contract_path().
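A short sketch of that inspection, using small NumPy inputs (the path depends only on the shapes, so the report is the same for CuPy arrays of equal shape). opt_einsum's default memory_limit is None, so the path it finds for this expression does create dim^5 intermediates; the values in the comments assume all dimensions are equal, as in the question:

```python
import numpy as np
import opt_einsum as oe

rng = np.random.default_rng(3)
dim = 8
U = rng.random((dim, dim, dim))
T_start = rng.random((dim,) * 4)
expr = "msc,iju,fkim,kbjs->fbuc"

path, info = oe.contract_path(expr, U, U, T_start, T_start)
print(path)                      # list of operand pairs, contracted in this order
print(info)                      # report: naive/optimized scaling, FLOPs, largest intermediate
print(max(info.scale_list))      # 7: the costliest step involves seven indices
print(int(info.largest_intermediate) == dim**5)  # True: a dim^5 intermediate is created

# The path can be reused for the actual contraction, e.g. inside a loop over many inputs.
T_final2 = oe.contract(expr, U, U, T_start, T_start, optimize=path)
```

If the dim^4 budget is a hard requirement, largest_intermediate is the number to check on a small instance before running the contraction at full size.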

    To verify how they behave, we can do some measurements (code below). In the measured range (dim = 4 to 24, NumPy on the CPU), the runtime of the original implementation grows roughly like dim^5, while the optimized one grows roughly like dim^4:

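    [Figure: timing diagram, log-log plot of time vs. dims for 'orig' and 'opt_einsum', with n^4 and n^5 reference lines]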

    import numpy as np
    import opt_einsum as oe
    import matplotlib.pyplot as plt
    from time import perf_counter

    np.random.seed(0)
    cp = np  # benchmark with NumPy on the CPU; use `import cupy as cp` for the GPU

    def eval_dim(dim):
        print(f'dim = {dim}')
        T_final = cp.empty((dim, dim, dim, dim))  # fully overwritten by the loops
        T_start = cp.random.rand(dim, dim, dim, dim)
        U = cp.random.rand(dim, dim, dim)

        # original nested loops
        # (with CuPy, call cp.cuda.Stream.null.synchronize() before reading the clock)
        start = perf_counter()
        for xf in range(0, dim):
            for yu in range(0, dim):
                B = cp.einsum("ij,kim->kmj", U[:, :, yu], T_start[xf, :, :, :])
                for xb in range(0, dim):
                    C = cp.einsum("kmj,kjs->ms", B, T_start[:, xb, :, :])
                    T_final[xf, xb, yu, :] = cp.einsum("msc,ms->c", U, C)
        time_orig = perf_counter() - start

        # single contraction, evaluation order chosen by opt_einsum
        start = perf_counter()
        T_final2 = oe.contract("msc,iju,fkim,kbjs->fbuc", U, U, T_start, T_start)
        time_oe = perf_counter() - start

        print(time_orig, time_oe)
        print(cp.allclose(T_final, T_final2))  # should print True
        return time_orig, time_oe

    def main():
        dims = np.arange(4, 25)
        times = np.array([eval_dim(dim) for dim in dims])
        plt.loglog(dims, times[:, 0], 'o-')
        plt.loglog(dims, times[:, 1], 'o-')
        plt.loglog(dims, 3e-7 * dims**4, ':')
        plt.loglog(dims, 3e-7 * dims**5, ':')
        plt.xlabel('dims'); plt.ylabel('time')
        plt.legend(['orig', 'opt_einsum', 'n^4', 'n^5'])
        plt.show()

    main()
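For reference, a rough count of multiply-adds (constant factors dropped, d = dim, T = T_start, index letters as in the einsum strings above) suggests that the dim^5 and dim^4 slopes in the timing plot most likely reflect Python and per-call overhead at these small sizes rather than the asymptotic cost: both approaches do on the order of d^7 operations. X and Y are just names for the intermediates of one possible pairwise order, not necessarily the one opt_einsum picks.

```latex
\begin{aligned}
\text{original loops:}\quad
  & \underbrace{d^2 \cdot d^4}_{ij,kim\to kmj}
  + \underbrace{d^3 \cdot d^4}_{kmj,kjs\to ms}
  + \underbrace{d^3 \cdot d^3}_{msc,ms\to c}
  = \mathcal{O}(d^7) \\[6pt]
\text{one pairwise order:}\quad
  & \underbrace{d^6}_{X_{fkmju} = \sum_i T_{fkim} U_{iju}}
  + \underbrace{d^7}_{Y_{fmubs} = \sum_{k,j} X_{fkmju} T_{kbjs}}
  + \underbrace{d^6}_{\sum_{m,s} U_{msc} Y_{fmubs}}
  = \mathcal{O}(d^7)
\end{aligned}
```

Every pairwise order of these four operands contains at least one step that involves seven indices, so the fused call wins through far fewer Python-level calls and larger, BLAS-friendly contractions, not through a lower exponent.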