Tags: python, numpy, matrix, scipy, sparse-matrix

Row-wise outer product on sparse matrices


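Given two sparse SciPy matrices A and B with the same number of rows, I want to compute their row-wise outer product: row i of the result is the flattened outer product of row i of A and row i of B.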

I can do this with NumPy on dense arrays in a number of ways, the easiest perhaps being

np.einsum('ij,ik->ijk', A, B).reshape(n, -1)

or

(A[:, :, np.newaxis] * B[:, np.newaxis, :]).reshape(n, -1)

where n is the number of rows in A and B.
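To make the target concrete, here is a minimal sketch on two tiny dense arrays (the values are made up for illustration). It checks that both expressions above agree with a plain per-row `np.outer`, and that the result has one column per pair of input columns.

```python
import numpy as np

# Two small dense arrays with the same number of rows (hypothetical values).
A = np.array([[1, 2],
              [3, 4]])
B = np.array([[1, 0, 2],
              [0, 5, 0]])
n = A.shape[0]

via_einsum = np.einsum('ij,ik->ijk', A, B).reshape(n, -1)
via_broadcast = (A[:, :, np.newaxis] * B[:, np.newaxis, :]).reshape(n, -1)

# Reference: flatten np.outer(A[i], B[i]) for every row i.
per_row = np.stack([np.outer(a, b).ravel() for a, b in zip(A, B)])

# row 0 -> [1, 0, 2, 2, 0, 4]
# row 1 -> [0, 15, 0, 0, 20, 0]
print(via_einsum)
print(np.array_equal(via_einsum, via_broadcast),
      np.array_equal(via_einsum, per_row))
```

With L columns in A and K columns in B, the entry A[i, j] * B[i, k] ends up in column j*K + k of row i.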

In my case, however, going through dense matrices eats up way too much RAM. The only option I have found is thus to use a Python loop:

sp.sparse.vstack((ra.T@rb).reshape(1,-1) for ra, rb in zip(A,B)).tocsr()

While using less RAM, this is very slow.
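For a sense of scale, the following back-of-the-envelope sketch compares the dense and CSR footprints of the result. The sizes and density are hypothetical, picked only for illustration, and the CSR figure assumes float64 values, 32-bit indices, and nonzeros spread uniformly at random.

```python
# Hypothetical problem size, chosen only for illustration.
n, L, K = 100_000, 768, 768
density = 0.03

# Dense float64 result: one 8-byte value for every (row, j, k) triple.
dense_bytes = n * L * K * 8

# Expected number of nonzeros in the result, assuming the nonzeros of
# A and B are spread uniformly at random at the given density.
nnz = n * (density * L) * (density * K)

# CSR storage: 8-byte value + 4-byte column index per entry,
# plus one 4-byte row pointer per row (and one extra).
csr_bytes = nnz * (8 + 4) + (n + 1) * 4

print(f"dense: {dense_bytes / 1e9:.0f} GB")
print(f"CSR:   {csr_bytes / 1e9:.2f} GB")
```

Under these assumptions the dense result needs several hundred gigabytes, while the sparse result stays below one gigabyte.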

My question is thus: is there a sparse (RAM-efficient) way to take the row-wise outer product of two matrices while keeping things vectorized?

(A similar question is "numpy elementwise outer product with sparse matrices", but all answers there go through dense matrices.)


Solution

  • We can directly calculate the CSR representation of the result. It's not super fast (about 3 seconds for two 100,000 x 768 inputs at 3% density) but may be OK, depending on your use case:

    import numpy as np
    import itertools
    from scipy import sparse
    
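    def spouter(A,B):
        # A and B must be CSR matrices with the same number of rows.
        N,L = A.shape
        N,K = B.shape
        # Pair up the nonzero values of each row of A with those of the same row of B.
        drows = zip(*(np.split(x.data,x.indptr[1:-1]) for x in (A,B)))
        data = [np.outer(a,b).ravel() for a,b in drows]
        # Do the same for the column indices; the pair (j,k) maps to flat column j*K+k.
        irows = zip(*(np.split(x.indices,x.indptr[1:-1]) for x in (A,B)))
        indices = [np.ravel_multi_index(np.ix_(a,b),(L,K)).ravel() for a,b in irows]
        # Row pointers are the running total of entries per output row.
        indptr = np.fromiter(itertools.chain((0,),map(len,indices)),int).cumsum()
        return sparse.csr_matrix((np.concatenate(data),np.concatenate(indices),indptr),(N,L*K))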
    
    A = sparse.random(100,768,0.03).tocsr()
    B = sparse.random(100,768,0.03).tocsr()
    
    # Check against the dense einsum result on the small example.
    print(np.all(np.einsum('ij,ik->ijk',A.toarray(),B.toarray()).reshape(100,-1) == spouter(A,B).toarray()))
    
    A = sparse.random(100000,768,0.03).tocsr()
    B = sparse.random(100000,768,0.03).tocsr()
    
    from time import time
    T = time()
    C = spouter(A,B)
    print(time()-T)
    

    Sample run:

    True
    3.1073222160339355
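The function above still runs a Python-level loop over rows inside its list comprehensions. As a possible alternative (not the method from the answer above, and not benchmarked here), the same CSR arrays can be built with array operations only. The sketch below uses a hypothetical name, `spouter_vectorized`: it expands every output entry into a (row, position-in-row) pair and derives the matching entries of A and B from that position.

```python
import numpy as np
from scipy import sparse

def spouter_vectorized(A, B):
    """Row-wise outer product of two sparse matrices, built without a row loop.

    Hypothetical helper for illustration; A and B need the same number of rows.
    """
    A, B = A.tocsr(), B.tocsr()
    N, L = A.shape
    N_b, K = B.shape
    if N != N_b:
        raise ValueError("A and B must have the same number of rows")

    # Nonzeros per row; int64 avoids overflow in the products and sums below.
    na = np.diff(A.indptr).astype(np.int64)
    nb = np.diff(B.indptr).astype(np.int64)
    nc = na * nb                                  # nonzeros per output row
    indptr = np.concatenate(([0], np.cumsum(nc)))

    # For every output entry: its row, and its position within that row.
    row = np.repeat(np.arange(N), nc)
    pos = np.arange(indptr[-1]) - indptr[row]

    # Split the position into an offset into A's row and an offset into B's row.
    # nb[row] is never zero here, because rows with nb == 0 produce no entries.
    ia = A.indptr[row] + pos // nb[row]
    ib = B.indptr[row] + pos % nb[row]

    data = A.data[ia] * B.data[ib]
    # Column pair (j, k) maps to flat column j*K + k.
    indices = A.indices[ia].astype(np.int64) * K + B.indices[ib]
    return sparse.csr_matrix((data, indices, indptr), shape=(N, L * K))

# Small consistency check against the dense einsum formulation.
A = sparse.random(30, 7, density=0.3, format='csr', random_state=0)
B = sparse.random(30, 5, density=0.3, format='csr', random_state=1)
C = spouter_vectorized(A, B)
expected = np.einsum('ij,ik->ijk', A.toarray(), B.toarray()).reshape(30, -1)
print(np.allclose(C.toarray(), expected))
```

This version trades the per-row loop for a few temporary integer arrays with one element per output nonzero, so its peak memory is somewhat higher than the final result alone.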