python numpy scipy sparse-matrix

Scipy equivalent of numpy where for sparse matrices


I am looking for an equivalent of numpy.where to use with the sparse representations that SciPy offers (scipy.sparse). Is there anything that lets you handle those matrices as if you were using an if-then-else statement?

UPDATE To be more specific: I need where as a vectorized if-then-else function, i.e. for tasks like: for each entry of matrix A that equals K, put the corresponding value of matrix B, else that of C. You could use something like find to retrieve the indices of the entries that satisfy the logical condition, then take the complement to get all the remaining ones, but for sparse matrices, isn't there a more compact way?
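For reference, here is the dense NumPy form of the task described in the update, on small made-up arrays (the values of A, B, C and K are illustrative only). This is the behaviour a sparse equivalent would need to reproduce.

```python
import numpy as np

# Small dense stand-ins for the matrices A, B, C in the question (made-up values)
K = 3
A = np.array([[3, 0, 1],
              [0, 3, 3]])
B = np.array([[10, 20, 30],
              [40, 50, 60]])
C = np.array([[-1, -2, -3],
              [-4, -5, -6]])

# Where A equals K take the value from B, otherwise the value from C
result = np.where(A == K, B, C)
print(result)
```

With sparse matrices the condition itself matters: as far as I know, SciPy keeps `A == K` sparse when K is a nonzero scalar, while `A == 0` emits a SparseEfficiencyWarning because nearly every entry of the result would be True.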


Solution

  • Here are functions that duplicate np.where when cond, x, and y are sparse matrices of matching size.

    import numpy as np
    from scipy import sparse

    def where1(cond, x):
        # elements of x where cond is nonzero
        row, col, _ = sparse.find(cond)  # nonzero positions (effectively the coo format)
        # fancy-index x at those positions; flatten the (1, n) result to 1-D
        vals = np.asarray(x.tocsr()[row, col]).ravel()
        zs = sparse.coo_matrix((vals, (row, col)), shape=cond.shape).tocsr()
        zs.eliminate_zeros()  # drop positions where cond is set but x is 0
        return zs

    def where2(cond, y):
        # elements of y where not cond
        row, col, _ = sparse.find(cond)
        zs = y.copy().tolil()  # faster for this than the csr format
        zs[row, col] = 0       # blank out every position where cond is set
        zs = zs.tocsr()
        zs.eliminate_zeros()   # remove any zeros left stored explicitly
        return zs

    def where(cond, x, y):
        # like np.where(cond, x, y), but with sparse matrices
        ws1 = where1(cond, x)
        # ws2 = where1(cond==0, y)  # cond==0 is likely to produce a SparseEfficiencyWarning
        ws2 = where2(cond, y)
        ws = ws1 + ws2
        # sanity check against the dense np.where (only practical for small matrices)
        w = np.where(cond.toarray(), x.toarray(), y.toarray())
        assert np.allclose(ws.toarray(), w)
        return ws

    m, n, d = 100, 90, 0.5
    cs = sparse.rand(m, n, d)
    xs = sparse.rand(m, n, d)
    ys = sparse.rand(m, n, d)
    print(where(cs, xs, ys).toarray())
    

    Even after figuring out how to code where1, it took further thought to find a way to handle the "not" side of the problem without generating a warning. It's not as general or as fast as the dense where, but it illustrates the complexity involved in building sparse matrices this way.
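A possibly more compact variant, which is a swapped-in technique and not the code above: build one boolean mask with `cond != 0` (to my knowledge SciPy keeps this sparse and does not warn, unlike `cond == 0`) and select with elementwise `multiply`. The "not" side becomes `y - y.multiply(mask)`, so no complement mask is ever formed. The helper names `where_mask` and `random_sparse` are hypothetical, and the test matrices are made up.

```python
import numpy as np
from scipy import sparse

def where_mask(cond, x, y):
    """Hypothetical sparse analogue of np.where(cond, x, y).

    Assumes cond, x and y are SciPy sparse matrices of the same shape.
    """
    mask = cond.tocsr() != 0           # sparse boolean mask, True where cond is set
    picked_x = x.tocsr().multiply(mask)  # x where cond, 0 elsewhere
    y = y.tocsr()
    picked_y = y - y.multiply(mask)      # y where not cond, 0 elsewhere
    out = (picked_x + picked_y).tocsr()
    out.eliminate_zeros()                # drop any explicitly stored zeros
    return out

rng = np.random.default_rng(0)

def random_sparse(shape, density):
    # made-up test data: uniform values, roughly `density` of them kept
    a = rng.random(shape)
    a[rng.random(shape) >= density] = 0.0
    return sparse.csr_matrix(a)

cs = random_sparse((6, 5), 0.5)
xs = random_sparse((6, 5), 0.5)
ys = random_sparse((6, 5), 0.5)

ws = where_mask(cs, xs, ys)
dense = np.where(cs.toarray() != 0, xs.toarray(), ys.toarray())
print(np.allclose(ws.toarray(), dense))
```

Subtracting `y.multiply(mask)` zeroes exactly the positions where cond is set, which is the job `where2` does with lil assignment.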

    It's worth noting that

    np.where(cond)    # same as np.nonzero(cond) when only cond is given, see doc
    
    xs.nonzero()      # for coo format: (xs.row, xs.col), minus any explicitly stored zeros
    sparse.find(xs)   # (row, col, data) of the nonzero entries
    

    For 1-D inputs, np.where with x and y is equivalent to:

    [xv if c else yv for (c,xv,yv) in zip(condition,x,y)]  # see doc
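A quick check of that equivalence on made-up 1-D arrays: the list comprehension and `np.where` pick the same element at every position.

```python
import numpy as np

condition = np.array([True, False, True, False])
x = np.array([1, 2, 3, 4])
y = np.array([10, 20, 30, 40])

# Element-by-element definition from the np.where docs (1-D case)
looped = [xv if c else yv for (c, xv, yv) in zip(condition, x, y)]
# Vectorized version of the same selection
vectorized = np.where(condition, x, y)
print(looped, vectorized)
```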
    

    The C code probably implements this with an nditer, which works much like zip, stepping through every element of the inputs and the output. If the output is anywhere close to dense (e.g. with a scalar y=2, every position where cond is false becomes nonzero), then np.where will be faster than this sparse substitute.
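To illustrate why a scalar y such as 2 defeats sparsity, this sketch uses a made-up condition and x that are each about 10% nonzero and measures how dense the `np.where` result is. It only counts nonzeros; it does not time anything.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
# Made-up sparse inputs: roughly 10% of entries are nonzero
cond = sparse.csr_matrix(np.where(rng.random((200, 200)) < 0.1, 1.0, 0.0))
x = sparse.csr_matrix(np.where(rng.random((200, 200)) < 0.1, 1.0, 0.0))

# With a scalar y=2, every position where cond is False becomes 2
cond_dense = cond.toarray() != 0
out = np.where(cond_dense, x.toarray(), 2)

print(f"cond density:   {cond.nnz / out.size:.2f}")
print(f"output density: {np.count_nonzero(out) / out.size:.2f}")
```

Since every "false" position is filled, the output can have at most as many zeros as cond has nonzeros, so a sparse representation of it saves nothing.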