Tags: python, numpy, linear-algebra, torch

Sparse and huge matrix multiplication in pytorch or numpy


I have a scenario where I need to multiply a small vector a by a huge, highly sparse matrix b. Here's a simplified version of the code:

import numpy as np

B = 32
M = 10000000

a = np.random.rand(B)
b = np.random.rand(B, M)
b = b > 0.9  # boolean mask: roughly 10% of entries are True (nonzero)

result = a @ b

In my actual use case, the b matrix is loaded from a np.memmap file due to its large size. Importantly, b remains unchanged throughout the process, and I will be computing the inner product with a different vector a each time, so any preprocessing of b that leverages its sparse nature is allowed.
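For reference, the loading step looks roughly like this; the filename and dtype below are placeholders, not the real ones:

import numpy as np

B = 32
M = 10000000

# 'b.dat' and np.bool_ stand in for the actual memmap file and dtype.
b = np.memmap('b.dat', dtype=np.bool_, mode='r', shape=(B, M))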

I'm seeking suggestions on how to optimize this matrix multiplication speed. Any insights or code examples would be greatly appreciated.


Solution

  • Testing with a smaller M - I don't want to hang my machine with memory errors or long calcs:

    In [340]: B = 32
         ...: M = 10000
         ...: 
         ...: a = np.random.rand(B)
         ...: b = np.random.rand(B, M)
         ...: b = b > 0.9
         ...: 
         ...: 
    In [341]: timeit a@b
    1.23 ms ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
    In [342]: timeit np.einsum('i,ij',a,b)
    590 µs ± 3.83 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
    In [343]: timeit np.einsum('i,ij',a,b, optimize=True)
    1.57 ms ± 83 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
    

    I'm a little surprised that einsum is so much faster. Often einsum ends up using the same BLAS functions as matmul, especially if optimize is turned on. The big difference between the B and M dimensions probably has something to do with it; BLAS-type code was probably written with 'squarish' matrices in mind.
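    One thing that may be worth timing (a sketch I haven't benchmarked here): since b is boolean, each a@b has to promote b to float before any fast path can apply, and because b never changes, that cast can be hoisted out:

    import numpy as np

    B, M = 32, 10000
    a = np.random.rand(B)
    b = np.random.rand(B, M) > 0.9

    # Cast the unchanging boolean matrix to float once, so repeated
    # products use the float path instead of promoting b every call.
    b_f = b.astype(np.float64)

    assert np.allclose(a @ b, a @ b_f)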

    Trying scipy.sparse:

    In [344]: from scipy import sparse
    In [345]: bM = sparse.csr_matrix(b)
    In [346]: timeit bM = sparse.csr_matrix(b)
    3.19 ms ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    In [347]: bM
    Out[347]: 
    <32x10000 sparse matrix of type '<class 'numpy.bool_'>'
        with 32241 stored elements in Compressed Sparse Row format>
    In [348]: bM.indptr
    Out[348]: 
    array([    0,  1017,  2008,  3028,  4049,  5041,  5971,  6955,  7930,
            8957,  9986, 10978, 12031, 13050, 14067, 15072, 16142, 17140,
           18093, 19106, 20074, 21096, 22122, 23150, 24152, 25170, 26197,
           27232, 28271, 29233, 30246, 31226, 32241], dtype=int32)
    

    Creating the sparse matrix takes time; in effect it has to do np.nonzero(b) to find the nonzero elements.
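    If densifying the whole memmap at once is a problem, one possibility (an untested sketch; the chunk size is arbitrary) is to build the CSR matrix a column-chunk at a time and stack the pieces:

    import numpy as np
    from scipy import sparse

    def csr_from_memmap(b, chunk=1_000_000):
        # Convert one column slice at a time so only a small part of
        # the memmapped array is ever dense in memory.
        pieces = [sparse.csr_matrix(np.asarray(b[:, i:i + chunk]))
                  for i in range(0, b.shape[1], chunk)]
        return sparse.hstack(pieces, format='csr')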

    Having made it, the multiplication proceeds as before, delegating to bM's own version of matrix multiplication (don't use np.matmul, np.dot, or np.einsum with bM):

    In [349]: res = a@bM
    In [350]: type(res)
    Out[350]: numpy.ndarray
    In [351]: np.allclose(res, a@b)
    Out[351]: True
    In [352]: timeit a@bM
    268 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
    

    and the time is better.

    I'm not sure how the csr_matrix() step would work with a memmap array. But once created, that matrix can probably be kept in memory, saving the disk-access time.
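    If that conversion is done once as a preprocessing step, the result can also be saved in scipy's native sparse format and reloaded on later runs, skipping the conversion cost entirely (the filename is illustrative):

    import numpy as np
    from scipy import sparse

    B, M = 32, 10000
    a = np.random.rand(B)
    b = np.random.rand(B, M) > 0.9

    # One-time preprocessing: convert and save.
    sparse.save_npz('b_sparse.npz', sparse.csr_matrix(b))

    # Later runs load the ready-made CSR matrix directly.
    bM = sparse.load_npz('b_sparse.npz')
    res = a @ bM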