Tags: python, matlab, numpy, scipy, sparse-matrix

What is the most efficient way to set rows to zero in a sparse SciPy matrix?


I'm trying to convert the following MATLAB code to Python and am having trouble finding a solution that runs in a reasonable amount of time.

M = diag(sum(a)) - a;
where = vertcat(in, out);
M(where,:) = 0;
M(where,where) = 1;

Here, a is a sparse matrix, and where is a vector of indices (as are in and out). My Python solution is:

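import numpy
import scipy.sparse

# degs = column sums of A (MATLAB's sum(a)), e.g. numpy.asarray(A.sum(axis=0)).ravel()
M = scipy.sparse.diags([degs], [0]) - A
where = numpy.hstack((inVs, outVs)).astype(int)
M = scipy.sparse.lil_matrix(M)
M[where, :] = 0  # This is the slowest line
# Paired indexing: sets only M[w, w], unlike MATLAB's block assignment M(where,where) = 1
M[where, where] = 1
M = scipy.sparse.csc_matrix(M)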

Because A is 334863 x 334863, this takes about three minutes. If anyone has suggestions on how to make it faster, please share them! For comparison, MATLAB performs the same step almost instantly.

Thanks!
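Note that MATLAB's M(where,where) = 1 and the NumPy-style M[where, where] = 1 are not the same operation. The small dense sketch below uses a made-up 4 x 4 matrix (illustration only) to show the difference: np.ix_ reproduces MATLAB's block assignment, while passing two index arrays pairs them element by element and touches only the diagonal entries (w, w). Which form is intended depends on the application.

```python
import numpy as np

# Toy 4x4 symmetric adjacency matrix (made-up data for illustration only).
a = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
where = np.array([0, 2])

# M = diag(sum(a)) - a; MATLAB's sum(a) sums each column, i.e. axis=0.
M = np.diag(a.sum(axis=0)) - a

# MATLAB semantics: M(where,:) = 0; M(where,where) = 1 sets the whole block.
block = M.copy()
block[where, :] = 0
block[np.ix_(where, where)] = 1

# NumPy/SciPy paired indexing: only (0, 0) and (2, 2) are set to 1.
paired = M.copy()
paired[where, :] = 0
paired[where, where] = 1

print(block)
print(paired)
```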


Solution

  • The solution I use for a similar task is credited to @seberg and avoids converting to LIL format:

    import scipy.sparse
    import numpy
    import time
    
    def csr_row_set_nz_to_val(csr, row, value=0):
        """Set all nonzero elements (elements currently in the sparsity pattern)
        of the given row to the given value. Mostly useful for setting them to 0.
        """
        if not isinstance(csr, scipy.sparse.csr_matrix):
            raise ValueError('Matrix given must be of CSR format.')
        # In CSR, the stored values of row i are data[indptr[i]:indptr[i+1]].
        csr.data[csr.indptr[row]:csr.indptr[row+1]] = value
    
    def csr_rows_set_nz_to_val(csr, rows, value=0):
        for row in rows:
            csr_row_set_nz_to_val(csr, row, value)
        if value == 0:
            # Remove the explicit zeros so they leave the sparsity pattern.
            csr.eliminate_zeros()
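The loop in csr_rows_set_nz_to_val runs once per row at Python speed. When where is long, a vectorized variant can build a single boolean mask over csr.data instead: np.repeat(np.arange(n_rows), np.diff(csr.indptr)) gives the row index of every stored element. This is a sketch, and csr_zero_rows is a hypothetical name, not a SciPy function.

```python
import numpy as np
import scipy.sparse as sp


def csr_zero_rows(csr, rows):
    """Zero all stored entries of the given rows of a CSR matrix, in place.

    Hypothetical helper (not a SciPy function): one vectorized pass over
    csr.data instead of a Python loop over rows.
    """
    if csr.format != "csr":
        raise ValueError("Matrix given must be of CSR format.")
    n_rows = csr.shape[0]
    # Boolean flag per row: True for rows that should be cleared.
    row_mask = np.zeros(n_rows, dtype=bool)
    row_mask[rows] = True
    # Row index of every stored element: row i owns indptr[i+1] - indptr[i] entries.
    row_of_entry = np.repeat(np.arange(n_rows), np.diff(csr.indptr))
    # Zero the stored values that fall in a flagged row, then drop them.
    csr.data[row_mask[row_of_entry]] = 0
    csr.eliminate_zeros()


# Usage on a small, fully populated matrix.
A = sp.csr_matrix(np.arange(1, 10, dtype=float).reshape(3, 3))
csr_zero_rows(A, [0, 2])
print(A.toarray())
```

Duplicate indices in rows are harmless here, since setting a flag twice has no extra effect.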
    

    Wrap the evaluation in a timing function:

    def evaluate(size):
        degs = [1] * size
        step = size // 25  # integer step; size/25 is a float in Python 3
        inVs = list(range(1, size, step))
        outVs = list(range(5, size, step))
        where = numpy.hstack((inVs, outVs)).astype(int)
        start_time = time.perf_counter()
        A = scipy.sparse.csc_matrix((size, size))
        # .tocsr() guarantees the CSR format that the helper functions require
        M = (scipy.sparse.diags([degs], [0]) - A).tocsr()
        csr_rows_set_nz_to_val(M, where)
        return time.perf_counter() - start_time
    

    Then test its performance:

    >>> print('elapsed %.5f seconds' % evaluate(334863))
    elapsed 0.53054 seconds
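The helpers above only clear rows, while the question's code also sets M[w, w] = 1. A sketch that does both without LIL, using a different technique (diagonal scaling): left-multiplying by the 0/1 diagonal matrix diag(keep) zeroes the listed rows, and adding diag(1 - keep) places a 1 at each (w, w). zero_rows_set_diag is a hypothetical name, and it follows the paired NumPy semantics (diagonal entries only), not MATLAB's block assignment.

```python
import numpy as np
import scipy.sparse as sp


def zero_rows_set_diag(M, where):
    """Return a CSC copy of M with rows `where` zeroed and M[w, w] = 1.

    Hypothetical helper: matches the paired semantics of M[where, where] = 1,
    not MATLAB's block assignment M(where, where) = 1.
    """
    n = M.shape[0]
    keep = np.ones(n)
    keep[where] = 0.0  # 0 for rows to clear, 1 for rows to keep
    # diag(keep) @ M scales row i by keep[i], so the listed rows become zero.
    cleared = sp.diags(keep) @ M
    # diag(1 - keep) holds a 1 exactly at (w, w) for every w in where.
    result = sp.csc_matrix(cleared + sp.diags(1.0 - keep))
    result.eliminate_zeros()  # drop any explicit zeros left over
    return result


# Usage on a toy 4x4 adjacency matrix (made-up data).
a = sp.csr_matrix(np.array([[0, 1, 0, 1],
                            [1, 0, 1, 0],
                            [0, 1, 0, 1],
                            [1, 0, 1, 0]], dtype=float))
degs = np.asarray(a.sum(axis=0)).ravel()  # MATLAB's sum(a)
M = sp.diags(degs) - a
out = zero_rows_set_diag(M, np.array([0, 2]))
print(out.toarray())
```

Both steps are whole-matrix sparse operations, so there is no Python-level loop over rows and no format conversion to LIL.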