Tags: python, tensorflow, sparse-matrix, tensorflow2.0

Efficient boolean masking with Tensorflow SparseTensors


I want to mask out entire rows of a SparseTensor. For a dense tensor this would be easy with tf.boolean_mask, but there is no equivalent for SparseTensors. What I can do right now is iterate over every index in SparseTensor.indices and keep only those whose row is selected by the mask, e.g.:

masked_indices = list(filter(lambda index: masked_rows[index[0]], indices))

where masked_rows is a 1D boolean array with one entry per row; True means the row is kept, as with tf.boolean_mask.
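The Python-level filter above can be expressed as two vectorized TensorFlow ops: gather the mask entry for each stored element's row, then boolean-mask the index matrix. This is a minimal sketch; the example `indices` and `masked_rows` values are hypothetical, and it assumes `masked_rows` has one boolean per row of the dense shape.

```python
import tensorflow as tf

# Hypothetical example data: indices of a 2D SparseTensor with 4 rows.
indices = tf.constant([[0, 1], [1, 0], [1, 2], [2, 2], [3, 0]], dtype=tf.int64)
# masked_rows[r] is True when row r should be kept (same convention as the filter).
masked_rows = tf.constant([True, False, True, False])

# Look up the mask entry for the row of every stored element in one gather...
keep = tf.gather(masked_rows, indices[:, 0])
# ...then keep only the matching index rows, replacing the Python filter loop.
masked_indices = tf.boolean_mask(indices, keep)
print(masked_indices.numpy())
# [[0 1]
#  [2 2]]
```

Both ops run inside TensorFlow over the whole index matrix at once, so the cost no longer includes a Python lambda call per stored element.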

However, this is very slow because my SparseTensor is fairly large: it has 90k indices and will grow significantly. The filtering alone takes several seconds for a single data point, before I even apply SparseTensor.mask to the filtered indices. The approach has another flaw: it doesn't actually remove the rows (although in my case a row of all zeros is fine).
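Since all-zero rows are acceptable here, one option is `tf.sparse.retain`, which drops stored values according to a per-value boolean vector while leaving `dense_shape` unchanged. A minimal sketch, assuming a 2D SparseTensor and a per-row keep mask; the tensor contents are hypothetical example data.

```python
import tensorflow as tf

# Hypothetical example: a 4x3 SparseTensor and a per-row keep mask.
sp = tf.SparseTensor(indices=[[0, 1], [1, 0], [1, 2], [2, 2], [3, 0]],
                     values=[1.0, 2.0, 3.0, 4.0, 5.0],
                     dense_shape=[4, 3])
keep_rows = tf.constant([True, False, True, False])

# One boolean per stored value: is that value's row kept?
keep_values = tf.gather(keep_rows, sp.indices[:, 0])
# tf.sparse.retain drops the unwanted values but keeps dense_shape as is,
# so masked-out rows simply become all-zero rows.
zeroed = tf.sparse.retain(sp, keep_values)
print(tf.sparse.to_dense(zeroed).numpy())
# [[0. 1. 0.]
#  [0. 0. 0.]
#  [0. 0. 4.]
#  [0. 0. 0.]]
```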

Is there a better way to mask a SparseTensor by row, or is this the best approach?


Solution

You can do that like this:

    import tensorflow as tf
    
    def boolean_mask_sparse_1d(sparse_tensor, mask, axis=0):  # mask is assumed to be 1D
        mask = tf.convert_to_tensor(mask)
        # Coordinate of every stored value along `axis`, and whether that slice is kept.
        ind = sparse_tensor.indices[:, axis]
        mask_sp = tf.gather(mask, ind)
        # The masked dimension shrinks to the number of True entries in `mask`.
        # dense_shape (an int64 tensor) is used instead of the static .shape so this
        # also works when the static shape is unknown, and no dtype cast is needed.
        new_size = tf.math.count_nonzero(mask)
        new_shape = tf.concat([sparse_tensor.dense_shape[:axis], [new_size],
                               sparse_tensor.dense_shape[axis + 1:]], axis=0)
        # Exclusive running count of kept slices maps each old coordinate to its new one.
        mask_count = tf.cumsum(tf.cast(mask, tf.int64), exclusive=True)
        # Drop the stored entries that fall in masked-out slices.
        masked_idx = tf.boolean_mask(sparse_tensor.indices, mask_sp)
        new_idx_axis = tf.gather(mask_count, masked_idx[:, axis])
        # Rebuild the index matrix with the renumbered coordinate on `axis`.
        new_idx = tf.concat([masked_idx[:, :axis],
                             tf.expand_dims(new_idx_axis, 1),
                             masked_idx[:, axis + 1:]], axis=1)
        new_values = tf.boolean_mask(sparse_tensor.values, mask_sp)
        return tf.SparseTensor(new_idx, new_values, new_shape)
    
    # Test
    sp = tf.SparseTensor([[1], [3], [4], [6]], [1, 2, 3, 4], [7])
    mask = tf.constant([True, False, True, True, False, False, True])
    out = boolean_mask_sparse_1d(sp, mask)
    print(out.indices.numpy())
    # [[2]
    #  [3]]
    print(out.values.numpy())
    # [2 4]
    print(out.shape)
    # (4,)
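The test above only exercises a 1D tensor. For the case in the question (removing rows, i.e. axis 0, of a 2D SparseTensor), a compact variant can be written with `tf.sparse.retain` followed by the same exclusive-cumsum renumbering, and checked against dense `tf.boolean_mask`. This is a sketch under stated assumptions: `boolean_mask_sparse_rows` is a hypothetical helper name, the mask is assumed to be 1D with length `dense_shape[0]`, and the example data is made up.

```python
import tensorflow as tf

def boolean_mask_sparse_rows(sp, row_mask):
    """Hypothetical helper: drop rows (axis 0) of `sp` where row_mask is False."""
    row_mask = tf.convert_to_tensor(row_mask, dtype=tf.bool)
    # Keep only the stored values whose row survives.
    kept = tf.sparse.retain(sp, tf.gather(row_mask, sp.indices[:, 0]))
    # Old row index -> new row index (number of kept rows before it).
    new_row = tf.cumsum(tf.cast(row_mask, tf.int64), exclusive=True)
    new_indices = tf.concat(
        [tf.gather(new_row, kept.indices[:, :1]), kept.indices[:, 1:]], axis=1)
    # Row count becomes the number of kept rows; other dimensions are unchanged.
    new_shape = tf.concat(
        [tf.reshape(tf.math.count_nonzero(row_mask), [1]), sp.dense_shape[1:]],
        axis=0)
    return tf.SparseTensor(new_indices, kept.values, new_shape)

# Hypothetical example: a 4x3 SparseTensor, keeping rows 0 and 2.
sp = tf.SparseTensor(indices=[[0, 1], [1, 0], [1, 2], [2, 2], [3, 0]],
                     values=[1.0, 2.0, 3.0, 4.0, 5.0],
                     dense_shape=[4, 3])
row_mask = tf.constant([True, False, True, False])
out = boolean_mask_sparse_rows(sp, row_mask)
print(out.indices.numpy())
# [[0 1]
#  [1 2]]
print(tf.sparse.to_dense(out).numpy())
# [[0. 1. 0.]
#  [0. 0. 4.]]
```

Because `tf.sparse.retain` preserves the order of the stored values and the row renumbering is increasing, the result stays in the same row-major order as the input.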