Get the nonzeros of a row of a SparseTensor


I want to get all the nonzero entries of one row of a SparseTensor. Here "m" is the SparseTensor object and row is the row whose nonzero values and indices I want. The function should return an array of pairs, [(index, value)]. I hope I can get some help on the subject.

def nonzeros( m, row):
    res = []
    indices = m.indices
    values = m.values
    userindices = tf.where(tf.equal(indices[:,0], tf.constant(0, dtype=tf.int64)))
    res = tf.map_fn(lambda index:(indices[index][1], values[index]), userindices)
    return res

Error message in the terminal:

TypeError: Input 'strides' of 'StridedSlice' Op has type int32 that does not match type int64 of argument 'begin'.

EDIT: The input to nonzeros is built from cm, a coo_matrix:

m = tf.SparseTensor(indices=np.array([row,col]).T,
                        values=cm.data,
                        dense_shape=[10, 10])
nonzeros(m, 1)

If the data is

[[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  2.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]]

then the result for row 1 should be

[index, value]
[4,1]
[9,2]

Solution

  • The problem is that index inside the lambda is a Tensor, and you can't use it directly to index into indices or values. Use tf.gather instead. Also, the code you posted never uses the row parameter: it always compares against a hard-coded 0.

    Try this instead:

    import tensorflow as tf
    import numpy as np
    
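    def nonzeros(m, row):
        indices = m.indices
        values = m.values
        # Positions of the stored entries whose row index equals `row`; shape [N, 1].
        userindices = tf.where(tf.equal(indices[:, 0], row))
        # Column index of each matching entry; shape [N].
        found_idx = tf.gather(indices, userindices)[:, 0, 1]
        # Value of each matching entry; shape [N, 1].
        found_vals = tf.gather(values, userindices)[:, 0:1]
        # Pair them up as [index, value] rows (TF 1.x argument order: values first, then axis).
        res = tf.concat([tf.expand_dims(tf.cast(found_idx, tf.float64), -1), found_vals], axis=1)
        return res
    
    # Dense form of the test matrix, shown for reference only.
    data = np.array([[0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.],
                    [0., 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  2.]])
    
    m = tf.SparseTensor(indices=np.array([[0, 1], [0, 9], [1, 4], [1, 9]], dtype=np.int64),
                        values=np.array([1.0, 1.0, 1.0, 2.0]),
                        dense_shape=[2, 10])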
    
    with tf.Session() as sess:
        result = nonzeros(m, 1)
        print(sess.run(result))
    

    which prints:

    [[ 4.  1.]
     [ 9.  2.]]
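
To check the expected output independently of TensorFlow, the same two steps (find the positions whose row index matches, then gather the column index and value at those positions) can be written on plain Python lists. This is only a reference sketch: nonzeros_reference is a hypothetical helper name, and the lists below stand in for m.indices and m.values.

```python
def nonzeros_reference(indices, values, row):
    """Return [(column, value), ...] for every stored entry in `row`.

    `indices` is a list of (row, column) pairs and `values` the matching
    list of stored values, i.e. the COO layout a SparseTensor uses.
    """
    # Step 1 (what tf.where does above): positions whose row index matches.
    positions = [i for i, (r, _) in enumerate(indices) if r == row]
    # Step 2 (what tf.gather does above): column index and value at those positions.
    return [(indices[i][1], values[i]) for i in positions]


# Same test data as the SparseTensor in the answer.
indices = [(0, 1), (0, 9), (1, 4), (1, 9)]
values = [1.0, 1.0, 1.0, 2.0]

pairs = nonzeros_reference(indices, values, 1)
print(pairs)  # [(4, 1.0), (9, 2.0)]
```

A row with no stored entries gives an empty list here; in the TensorFlow version the corresponding result would be a tensor with zero rows.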