I am looking for an equivalent of numpy.where to be used with the sparse representations that scipy offers (scipy.sparse).
Is there anything that lets you deal with those matrices as if you were using an if-then-else statement?
UPDATE
To be more specific: I need where as a vectorized if-then-else function, i.e. for tasks like: for each value that equals K in matrix A, put the corresponding value from matrix B, else the one from C.
You could use something like find to retrieve the indexes of those entries that satisfy the logical condition, then negate them to find all the remaining ones, but for sparse matrices, isn't there a more compact way?
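For reference, the dense behaviour being asked for can be sketched as follows. The arrays and the value K = 3 are made-up illustrative inputs, not data from an actual problem.

```python
import numpy as np

K = 3  # hypothetical value to match in A

A = np.array([[3, 0, 1],
              [0, 3, 0]])
B = np.array([[10, 20, 30],
              [40, 50, 60]])
C = np.array([[-1, -2, -3],
              [-4, -5, -6]])

# Take the entry from B wherever A equals K, otherwise the entry from C.
result = np.where(A == K, B, C)
print(result)
```

The goal is the same element-wise selection, but without converting sparse inputs to dense arrays first.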
Here's a function that duplicates np.where, when cond, x, and y are matching sized sparse matrices.
import numpy as np
from scipy import sparse

def where1(cond, x):
    # elements of x where cond is nonzero
    row, col, data = sparse.find(cond)  # effectively the coo format
    data = np.ones(data.shape, dtype=x.dtype)
    zs = sparse.coo_matrix((data, (row, col)), shape=cond.shape)
    # fancy indexing can return a 1 x N matrix, so flatten it to 1-D
    xx = np.asarray(x.tocsr()[row, col]).ravel()
    zs.data[:] = xx
    zs = zs.tocsr()
    zs.eliminate_zeros()
    return zs

def where2(cond, y):
    # elements of y where not cond
    row, col, data = sparse.find(cond)
    zs = y.copy().tolil()  # faster for this than the csr format
    zs[row, col] = 0
    zs = zs.tocsr()
    zs.eliminate_zeros()
    return zs

def where(cond, x, y):
    # like np.where but with sparse matrices
    ws1 = where1(cond, x)
    # ws2 = where1(cond==0, y)  # cond==0 is likely to produce a SparseEfficiencyWarning
    ws2 = where2(cond, y)
    ws = ws1 + ws2
    # test against np.where (toarray() replaces the older .A shorthand)
    w = np.where(cond.toarray(), x.toarray(), y.toarray())
    assert np.allclose(ws.toarray(), w)
    return ws

m, n, d = 100, 90, 0.5
cs = sparse.rand(m, n, d)
xs = sparse.rand(m, n, d)
ys = sparse.rand(m, n, d)
print(where(cs, xs, ys).toarray())
Even after figuring out how to code where1, it took further thought to figure out a way to apply the not side of the problem without generating a warning. It's not as general or fast as the dense where, but it illustrates the complexity that's involved in building sparse matrices this way.
It's worth noting that
np.where(cond) == np.nonzero(cond) # see doc
xs.nonzero() == (xs.row, xs.col) # for coo format
sparse.find(xs) == (row, col, data)
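The three relations above can be checked on a small matrix. This is only a sketch with a made-up 2x3 input. The pairs are sorted before comparing because the order in which sparse.find returns entries is assumed not to be guaranteed across SciPy versions.

```python
import numpy as np
from scipy import sparse

dense = np.array([[0, 5, 0],
                  [7, 0, 0]])
xs = sparse.coo_matrix(dense)

# np.where with only a condition returns the same index arrays as np.nonzero.
rows_w, cols_w = np.where(dense)
rows_n, cols_n = np.nonzero(dense)
same_as_nonzero = (np.array_equal(rows_w, rows_n)
                   and np.array_equal(cols_w, cols_n))

# For the coo format, nonzero() reports the stored (row, col) coordinates.
nz_rows, nz_cols = xs.nonzero()
nz_pairs = sorted(zip(nz_rows.tolist(), nz_cols.tolist()))
coo_pairs = sorted(zip(xs.row.tolist(), xs.col.tolist()))

# sparse.find additionally returns the value stored at each coordinate.
row, col, data = sparse.find(xs)
triples = sorted(zip(row.tolist(), col.tolist(), data.tolist()))

print(same_as_nonzero, nz_pairs, coo_pairs, triples)
```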
np.where with x and y is equivalent to:
[xv if c else yv for (c,xv,yv) in zip(condition,x,y)] # see doc
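As a quick check of that equivalence, here is a minimal sketch on made-up 1-D inputs. The zip form only lines up element by element for 1-D arrays, whereas np.where also handles higher dimensions and broadcasting.

```python
import numpy as np

condition = np.array([True, False, True, False])
x = np.array([1, 2, 3, 4])
y = np.array([10, 20, 30, 40])

# Vectorized selection.
vectorized = np.where(condition, x, y)

# The same selection written as an explicit Python loop.
looped = [xv if c else yv for (c, xv, yv) in zip(condition, x, y)]

print(vectorized.tolist())
print([int(v) for v in looped])
```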
The C code probably implements this with an nditer, which is functionally like zip, stepping through all the elements of the inputs and output. If the output is anywhere close to dense (e.g. y=2), then np.where will be faster than this sparse substitute.