Tags: python, arrays, numpy, pytorch, memory-efficient

Numpy sum elements in a multi-dimensional array according to indices


I am dealing with very large multi-dimensional data, but let me take a 2D array as an example. Given a value array that changes every iteration,

arr = np.array([[ 1, 2, 3, 4, 5], [5, 6, 7, 8, 9]]) # a*b

and an index array that stays fixed across iterations,

idx = np.array([[[0, 1, 1], [-1, -1, -1]],
                [[5, 1, 3], [1, -1, -1]]]) # n*h*w, where n = a*b,

Here -1 means no index is applied. I want to get the result

res = np.array([[1+2+2, 0],
                [5+2+4, 2]]) # h*w

In practice, I am working with a very large 3D tensor (n ~ trillions) and a very sparse idx (i.e. lots of -1). Since idx is fixed, my current solution is to precompute an n*(h*w) array index_tensor filled with 0s and 1s, and then do

tmp = arr.reshape(1, n)
res = (tmp @ index_tensor).reshape([h,w])

It works fine but takes a huge amount of memory to store index_tensor. Is there an approach that takes advantage of the sparsity and immutability of idx to reduce the memory cost while keeping a reasonable running speed in Python (using numpy or pytorch would be best)? Thanks in advance!


Solution

  • Ignoring the -1 complication for the moment (so -1 simply picks the last element, 9), the straightforward indexing and summation is:

    In [58]: arr = np.array([[ 1, 2, 3, 4, 5], [5, 6, 7, 8, 9]])
    In [59]: idx = np.array([[[0, 1, 1], [2, 4, 6]],
        ...:                 [[5, 1, 3], [1, -1, -1]]])
    In [60]: arr.flat[idx]
    Out[60]: 
    array([[[1, 2, 2],
            [3, 5, 6]],
    
           [[5, 2, 4],
            [2, 9, 9]]])
    In [61]: _.sum(axis=-1)
    Out[61]: 
    array([[ 5, 14],
           [11, 20]])
    

    One way (not necessarily fast or memory efficient) of dealing with the -1 is with a masked array:

    In [62]: mask = idx<0
    In [63]: mask
    Out[63]: 
    array([[[False, False, False],
            [False, False, False]],
    
           [[False, False, False],
            [False,  True,  True]]])
    
    In [65]: ma = np.ma.masked_array(Out[60],mask)
    In [67]: ma
    Out[67]: 
    masked_array(
      data=[[[1, 2, 2],
             [3, 5, 6]],
    
            [[5, 2, 4],
             [2, --, --]]],
      mask=[[[False, False, False],
             [False, False, False]],
    
            [[False, False, False],
             [False,  True,  True]]],
      fill_value=999999)
    In [68]: ma.sum(axis=-1)
    Out[68]: 
    masked_array(
      data=[[5, 14],
            [11, 2]],
      mask=[[False, False],
            [False, False]],
      fill_value=999999)
    

    Masked arrays deal with operations like the sum by replacing the masked values with something neutral, such as 0 for the case of sums.
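
To make the "neutral value" remark concrete, here is a minimal self-contained sketch (same arr and idx as in the session above; the names gathered, masked_sum and filled_sum are only illustrative). It compares the masked sum with filling the masked slots with 0 by hand and summing; the two should agree.

```python
import numpy as np

arr = np.array([[1, 2, 3, 4, 5], [5, 6, 7, 8, 9]])
idx = np.array([[[0, 1, 1], [2, 4, 6]],
                [[5, 1, 3], [1, -1, -1]]])

# Plain fancy indexing: -1 wraps around to the last element (9).
gathered = arr.flat[idx]

# Mask the slots that came from a -1 index.
ma = np.ma.masked_array(gathered, mask=idx < 0)

masked_sum = ma.sum(axis=-1)            # masked entries are skipped
filled_sum = ma.filled(0).sum(axis=-1)  # same result, zero-filled by hand
```

Note that both versions still materialise an array the size of idx, so this is about correctness rather than memory.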

    (I may revisit this in the morning).

    sum with matrix product

    In [72]: np.einsum('ijk,ijk->ij',Out[60],~mask)
    Out[72]: 
    array([[ 5, 14],
           [11,  2]])
    

    This is more direct, and faster, than the masked array approach.
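
The einsum line above relies on IPython's Out[60] history. A self-contained sketch of the same idea, assuming the validity mask is computed once because idx never changes (valid and gather_sum_einsum are illustrative names, not from the question), might look like this:

```python
import numpy as np

idx = np.array([[[0, 1, 1], [2, 4, 6]],
                [[5, 1, 3], [1, -1, -1]]])

# Computed once, since idx is fixed: True where an index is real.
valid = idx >= 0

def gather_sum_einsum(arr, idx, valid):
    """Gather arr at idx, zero out the -1 slots, and sum the last axis."""
    gathered = arr.ravel()[idx]  # -1 picks the last element; masked next
    return np.einsum('ijk,ijk->ij', gathered, valid)

arr = np.array([[1, 2, 3, 4, 5], [5, 6, 7, 8, 9]])
res = gather_sum_einsum(arr, idx, valid)
```

Only arr changes between iterations, so the per-iteration work is one gather and one einsum.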

    You haven't elaborated on constructing the index_tensor so I won't try to compare it.

    Another possibility is to pad the array with a 0, and adjust indexing:

    In [83]: arr1 = np.hstack((0,arr.ravel()))
    In [84]: arr1
    Out[84]: array([0, 1, 2, 3, 4, 5, 5, 6, 7, 8, 9])
    In [85]: arr1[idx+1]
    Out[85]: 
    array([[[1, 2, 2],
            [3, 5, 6]],
    
           [[5, 2, 4],
            [2, 0, 0]]])
    In [86]: arr1[idx+1].sum(axis=-1)
    Out[86]: 
    array([[ 5, 14],
           [11,  2]])
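
Since idx is fixed, the shifted index only needs to be built once. A minimal sketch of the padding idea as a reusable function, using the idx from the question (padded_idx and gather_sum_padded are illustrative names):

```python
import numpy as np

# idx from the question; -1 means "no index".
idx = np.array([[[0, 1, 1], [-1, -1, -1]],
                [[5, 1, 3], [1, -1, -1]]])

# Computed once: shift by one so that -1 points at the zero pad in slot 0.
padded_idx = idx + 1

def gather_sum_padded(arr, padded_idx):
    """Prepend a 0 to the flattened arr, gather, and sum the last axis."""
    padded = np.concatenate(([0], arr.ravel()))
    return padded[padded_idx].sum(axis=-1)

arr = np.array([[1, 2, 3, 4, 5], [5, 6, 7, 8, 9]])
res = gather_sum_padded(arr, padded_idx)
```

This reproduces the result asked for in the question. It still allocates a temporary the size of idx on every call, so it does not exploit the sparsity.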
    

    sparse

    A first stab at using a sparse matrix:

    Reshape idx to 2d:

    In [141]: idx1 = np.reshape(idx,(4,3))
    

    Make a sparse matrix from that. For a start I'll use the iterative lil approach, though constructing coo (or even csr) inputs directly is usually faster:

    In [142]: from scipy import sparse
         ...: M = sparse.lil_matrix((4,10),dtype=int)
         ...: for i in range(4):
         ...:     for j in range(3):
         ...:         v = idx1[i,j]
         ...:         if v>=0:
         ...:             M[i,v] += 1   # += so a repeated index is counted twice
         ...: 
    In [143]: M
    Out[143]: 
    <4x10 sparse matrix of type '<class 'numpy.int64'>'
        with 9 stored elements in List of Lists format>
    In [144]: M.A
    Out[144]: 
    array([[1, 2, 0, 0, 0, 0, 0, 0, 0, 0],
           [0, 0, 1, 0, 1, 0, 1, 0, 0, 0],
           [0, 1, 0, 1, 0, 1, 0, 0, 0, 0],
           [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]])
    

    This can then be used for a sum of products:

    In [145]: M@arr.ravel()
    Out[145]: array([ 5, 14, 11,  2])
    

    Using M@arr.ravel() is essentially what you do. While M is sparse, arr is not. For this case M.A@ is faster than M@.
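
As a sketch of the "construct coo directly" remark, the loop can be replaced by building the coordinate arrays in one go. This assumes scipy is available and that repeated indices within one output cell should be counted once per occurrence; build_gather_matrix is an illustrative name. Only the non-negative entries of idx are stored, so the memory cost scales with the number of valid indices rather than with n*(h*w).

```python
import numpy as np
from scipy import sparse

def build_gather_matrix(idx, n):
    """Return an (h*w, n) CSR matrix whose entry (r, c) counts how often
    flat index c appears in output cell r; -1 entries are skipped."""
    idx2 = idx.reshape(-1, idx.shape[-1])   # one row per output cell
    rows, slots = np.nonzero(idx2 >= 0)     # positions of the valid indices
    cols = idx2[rows, slots]                # the flat indices into arr
    data = np.ones(len(rows), dtype=np.int64)
    coo = sparse.coo_matrix((data, (rows, cols)), shape=(idx2.shape[0], n))
    return coo.tocsr()                      # conversion sums duplicate entries

arr = np.array([[1, 2, 3, 4, 5], [5, 6, 7, 8, 9]])
idx = np.array([[[0, 1, 1], [2, 4, 6]],
                [[5, 1, 3], [1, -1, -1]]])

M = build_gather_matrix(idx, arr.size)      # built once, idx is fixed
res = (M @ arr.ravel()).reshape(idx.shape[:-1])
```

M is built once and reused; each iteration is then a single sparse matrix-vector product with the current arr. Whether this beats the dense product depends on how sparse idx really is, so it is worth timing on realistic sizes.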