Tags: python, numpy, scipy, sparse-matrix

scipy.sparse dot extremely slow in Python


The following code will not even finish on my system:

import numpy as np
from scipy import sparse
p = 100
n = 50
X = np.random.randn(p,n)
L = sparse.eye(p,p, format='csc')
X.T.dot(L).dot(X)

Is there any explanation why this matrix multiplication is hanging?


Solution

  • X.T.dot(L) is not, as you might expect, a 50x100 numeric matrix. It is a 50x100 array of objects, and each of its elements is itself a 100x100 sparse matrix:

    >>> X.T.dot(L).shape
    (50, 100)
    >>> X.T.dot(L)[0,0]
    <100x100 sparse matrix of type '<type 'numpy.float64'>'
        with 100 stored elements in Compressed Sparse Column format>
    

    The problem appears to be that X is a plain array, so its dot method does not know about sparse matrices and seems to treat L as a single opaque object rather than as a 100x100 matrix. One fix is to convert the sparse matrix to dense with its todense or toarray method. The former returns a matrix object, the latter an array:

    >>> X.T.dot(L.todense()).dot(X)
    matrix([[  81.85399873,    3.75640482,    1.62443625, ...,    6.47522251,
                3.42719396,    2.78630873],
            [   3.75640482,  109.45428475,   -2.62737229, ...,   -0.31310651,
                2.87871548,    8.27537382],
            [   1.62443625,   -2.62737229,  101.58919604, ...,    3.95235372,
                1.080478  ,   -0.16478654],
            ..., 
            [   6.47522251,   -0.31310651,    3.95235372, ...,   95.72988689,
              -18.99209596,   17.31774553],
            [   3.42719396,    2.87871548,    1.080478  , ...,  -18.99209596,
              108.90045569,  -16.20312682],
            [   2.78630873,    8.27537382,   -0.16478654, ...,   17.31774553,
              -16.20312682,  105.37102461]])
    

    Alternatively, call the sparse matrix's own dot method, which does know about arrays, so L never has to be converted to dense:

    >>> X.T.dot(L.dot(X))
    array([[  81.85399873,    3.75640482,    1.62443625, ...,    6.47522251,
               3.42719396,    2.78630873],
           [   3.75640482,  109.45428475,   -2.62737229, ...,   -0.31310651,
               2.87871548,    8.27537382],
           [   1.62443625,   -2.62737229,  101.58919604, ...,    3.95235372,
               1.080478  ,   -0.16478654],
           ..., 
           [   6.47522251,   -0.31310651,    3.95235372, ...,   95.72988689,
             -18.99209596,   17.31774553],
           [   3.42719396,    2.87871548,    1.080478  , ...,  -18.99209596,
             108.90045569,  -16.20312682],
           [   2.78630873,    8.27537382,   -0.16478654, ...,   17.31774553,
             -16.20312682,  105.37102461]])
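
As a rough check of the two approaches above, the following sketch computes the product both ways and compares the results. The seeded random generator and the variable names (rng, LX, result_sparse, result_dense, result_transposed) are illustrative assumptions, not part of the original question. The last variant relies on the identity X.T L = (L.T X).T, which is one possible way to keep the sparse matrix on the left when a dense array has to be the left factor.

```python
import numpy as np
from scipy import sparse

p, n = 100, 50
rng = np.random.default_rng(0)           # seeded only to make the sketch reproducible
X = rng.standard_normal((p, n))          # dense (p, n) array
L = sparse.eye(p, p, format='csc')       # sparse (p, p) identity

# Let the sparse matrix perform the product it takes part in.
LX = L.dot(X)                            # sparse.dot(dense) gives a dense (p, n) array
result_sparse = X.T.dot(LX)              # ordinary dense product, shape (n, n)

# Same computation after converting L to a dense array.
result_dense = X.T.dot(L.toarray()).dot(X)

# If the dense array must come first: X.T L == (L.T X).T
result_transposed = L.T.dot(X).T.dot(X)

print(result_sparse.shape)
print(np.allclose(result_sparse, result_dense))
print(np.allclose(result_sparse, result_transposed))
```

For a matrix this small the dense conversion costs little, but for a large sparse L it would allocate a full p x p array, so calling the sparse matrix's dot method is likely the safer default.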