
Memory and time in tensor operations in Python


Goal: My goal is to calculate the tensor T given by the formula below. The indices i, j, k, l each take 40 values (0 to 39) and p, m, x each take 80 values (0 to 79).

The formula for the tensor:

$$T_{ijkl} = \sum_{p,m,x} A_{pm}\, B_{ijpx}\, B_{klmx}$$

Tensordot approach: This summation is just a contraction of very large tensors over the three indices p, m and x. I tried np.tensordot, which supports this kind of contraction, but then memory becomes the problem, even if I do one tensordot followed by the other. (I work in Colab, so I have 12 GB of RAM available.)

Nested loops approach: There is additional structure in B: its only nonzero elements B_{ijpx} are those with i + j = p + x. I was therefore able to write p and m as functions of x (p = i + j - x, m = k + l - x) and reduce the computation to five nested loops over i, j, k, l and x. Now, however, the running time is the problem: the calculation takes 136 seconds, and I want to repeat it many times.
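Written out with the substitution p = i + j - x and m = k + l - x, and assuming p, m and x each run over 0 to 79, the triple sum collapses to a single sum over x, restricted to the values for which both p and m stay in range:

$$
T_{ijkl} = \sum_{x = x_{\min}}^{x_{\max}} A_{i+j-x,\; k+l-x}\; B_{i,\,j,\,i+j-x,\,x}\; B_{k,\,l,\,k+l-x,\,x},
\qquad
x_{\min} = \max(0,\; i+j-79,\; k+l-79),
\quad
x_{\max} = \min(79,\; i+j,\; k+l).
$$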

Timing goal for the nested-loop approach: reducing the time by a factor of 10 would be satisfactory, and a factor of 100 would be more than enough.

Do you have any ideas for either getting around the memory problem or reducing the running time? How do you handle such summations with additional constraints?

(Remark: The matrix A is symmetric, and I have not used this fact so far. There are no further symmetries.)
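Because A is symmetric, T is unchanged when the pair (i, j) is swapped with (k, l): writing out T_klij and renaming p and m gives back T_ijkl. A small NumPy check on random arrays (the sizes n and d are arbitrary stand-ins for 40 and 80, and B is left unconstrained because the identity only needs A to be symmetric):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 6                      # small stand-ins for 40 and 80
A = rng.random((d, d))
A = A + A.T                      # make A symmetric
B = rng.random((n, n, d, d))     # the i + j == p + x constraint is not needed here

T = np.einsum('pm,ijpx,klmx->ijkl', A, B, B)

# transpose(2, 3, 0, 1) maps T[k, l, i, j] to position [i, j, k, l]
sym_ok = np.allclose(T, T.transpose(2, 3, 0, 1))
print(sym_ok)  # True
```

In a loop-based version this could mean computing only the entries with (i, j) <= (k, l) in lexicographic order and copying each result into T[k, l, i, j], roughly halving the work; this is a suggestion, not something tested against the code below.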

Here is the code for the nested loops:

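import numpy as np

# A has shape (80, 80); B has shape (40, 40, 80, 80)
T = np.zeros((40, 40, 40, 40))
for i in range(40):
    for j in range(40):
        for k in range(40):
            for l in range(40):
                Sum = 0.0
                for x in range(80):
                    # B[i, j, p, x] is nonzero only when p = i + j - x
                    p = i + j - x
                    m = k + l - x
                    if 0 <= p < 80 and 0 <= m < 80:
                        Sum += A[p, m] * B[i, j, p, x] * B[k, l, m, x]
                T[i, j, k, l] = Sum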
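One NumPy-only alternative, a swapped-in technique rather than anything from the question or the answer: keep the single loop over x and replace the four outer loops with broadcasting. The helper name t_vectorized is hypothetical, and it assumes A has shape (dp, dp) and B has shape (ni, nj, dp, dx), for example (80, 80) and (40, 40, 80, 80). Each pass of the x loop builds a few arrays of shape (40, 40, 40, 40), about 20 MB each in float64, so peak memory stays modest.

```python
import numpy as np

def t_vectorized(A, B):
    """Hypothetical helper: loop over x only, broadcast over i, j, k, l.

    Assumes A has shape (dp, dp) and B has shape (ni, nj, dp, dx)
    and that B[i, j, p, x] is nonzero only when i + j == p + x.
    """
    ni, nj, dp, dx = B.shape
    ii, jj = np.indices((ni, nj))
    s = ii + jj                              # s[i, j] = i + j
    T = np.zeros((ni, nj, ni, nj), dtype=np.result_type(A, B))
    for x in range(dx):
        p = s - x                            # p = i + j - x (and m = k + l - x on the same grid)
        valid = (p >= 0) & (p < dp)          # same test as the loop's `if`
        pc = np.clip(p, 0, dp - 1)           # safe indices; invalid entries are zeroed through C
        # C[i, j] = B[i, j, i + j - x, x], or 0 where p is out of range
        C = np.where(valid, B[ii, jj, pc, x], 0.0)
        # Ax[i, j, k, l] = A[i + j - x, k + l - x]
        Ax = A[pc[:, :, None, None], pc[None, None, :, :]]
        T += C[:, :, None, None] * C[None, None, :, :] * Ax
    return T

# Small self-check against einsum on a B that satisfies i + j == p + x
rng = np.random.default_rng(1)
A_small = rng.random((7, 7))
A_small = A_small + A_small.T
B_small = rng.random((3, 4, 7, 7))
idx = np.indices(B_small.shape)
B_small[idx[0] + idx[1] != idx[2] + idx[3]] = 0
T_ref = np.einsum('pm,ijpx,klmx->ijkl', A_small, B_small, B_small)
print(np.allclose(t_vectorized(A_small, B_small), T_ref))  # True
```

The per-x work is a handful of vectorized operations rather than 40^4 Python-level iterations, so it should be much faster than the pure-Python loops; it has not been benchmarked here against the Numba version below.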

And here is the code for the tensordot approach:

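import numpy as np

# P[m, i, j, x] = sum_p A[p, m] * B[i, j, p, x]
P = np.tensordot(A, B, axes=([0], [2]))
# T[i, j, k, l] = sum_{m, x} P[m, i, j, x] * B[k, l, m, x]
T = np.tensordot(P, B, axes=([0, 3], [2, 3]))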
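The two tensordot calls can be checked against np.einsum on a small random problem (n and d are arbitrary stand-ins for 40 and 80). The comments track the axis order of each intermediate; at the question's sizes, P would have shape (80, 40, 40, 80), that is 10,240,000 float64 values or about 82 MB.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 5, 9                                   # small stand-ins for 40 and 80
A = rng.random((d, d))
B = rng.random((n, n, d, d))

# Step 1: contract A's first axis (p) with B's third axis (p).
# Remaining axes, in order: (m, i, j, x)
P = np.tensordot(A, B, axes=([0], [2]))
# Step 2: contract P's (m, x) with the second B's (m, x).
# Remaining axes, in order: (i, j, k, l)
T = np.tensordot(P, B, axes=([0, 3], [2, 3]))

T_ref = np.einsum('pm,ijpx,klmx->ijkl', A, B, B)
print(P.shape, T.shape, np.allclose(T, T_ref))
# (9, 5, 5, 9) (5, 5, 5, 5) True
```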

Solution

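  • Numba might be your best bet here. I put together this function based on your code, with one change: instead of looping over every x and filtering with the if block, it computes the valid range of x up front, so the out-of-range iterations are skipped entirely: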

    import numpy as np
    import numba as nb
    
    @nb.njit(parallel=True)
    def my_formula_nb(A, B):
        di, dj, dx, _ = B.shape
        T = np.zeros((di, dj, di, dj), dtype=A.dtype)
        # Only the outermost prange runs in parallel; inner loops are plain range
        for i in nb.prange(di):
            for j in range(dj):
                for k in range(di):
                    for l in range(dj):
                        s = 0.0
                        # Range of x for which p and m both lie in [0, dx)
                        x_start = max(0, i + j - dx + 1, k + l - dx + 1)
                        x_end = min(dx, i + j + 1, k + l + 1)
                        for x in range(x_start, x_end):
                            p = i + j - x
                            m = k + l - x
                            s += A[p, m] * B[i, j, p, x] * B[k, l, m, x]
                        T[i, j, k, l] = s
        return T
    

    Let's see it in action:

    import numpy as np
    
    def make_problem(di, dj, dx):
        a = np.random.rand(dx, dx)
        a = a + a.T
        b = np.random.rand(di, dj, dx, dx)
        b_ind = np.indices(b.shape)
        b_mask = b_ind[0] + b_ind[1] != b_ind[2] + b_ind[3]
        b[b_mask] = 0
        return a, b
    
    # Generate a problem
    np.random.seed(100)
    a, b = make_problem(15, 20, 25)
    # Solve with Numba function
    t1 = my_formula_nb(a, b)
    # Solve with einsum
    t2 = np.einsum('pm,ijpx,klmx->ijkl', a, b, b)
    # Check result
    print(np.allclose(t1, t2))
    # True
    
    # Benchmark (IPython)
    %timeit np.einsum('pm,ijpx,klmx->ijkl', a, b, b)
    # 4.5 s ± 39.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    %timeit my_formula_nb(a, b)
    # 6.06 ms ± 20.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    As you can see, on this small test problem the Numba solution is roughly 740 times faster than einsum (6.06 ms vs. 4.5 s), and it needs no intermediate arrays beyond the output T.
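The x_start/x_end bounds in my_formula_nb replace the question's if test. A pure-Python check (no Numba needed; the helper names and the small sizes are hypothetical) confirms that, for every (i, j, k, l) on a small grid, the bounds visit exactly the x values the if test keeps:

```python
def x_range_by_filter(i, j, k, l, dx):
    """Values of x kept by the `if` test in the question's loop."""
    return [x for x in range(dx)
            if 0 <= i + j - x < dx and 0 <= k + l - x < dx]

def x_range_by_bounds(i, j, k, l, dx):
    """Values of x visited by the precomputed bounds in my_formula_nb."""
    x_start = max(0, i + j - dx + 1, k + l - dx + 1)
    x_end = min(dx, i + j + 1, k + l + 1)
    return list(range(x_start, x_end))

# Exhaustive comparison on a small grid; i + j can exceed dx - 1 here,
# so both the lower and the upper bound on x get exercised
di, dj, dx = 6, 6, 8
mismatches = [
    (i, j, k, l)
    for i in range(di) for j in range(dj)
    for k in range(di) for l in range(dj)
    if x_range_by_filter(i, j, k, l, dx) != x_range_by_bounds(i, j, k, l, dx)
]
print(len(mismatches))  # 0
```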