
How to join/concat/combine ragged tensors in TensorFlow?


So basically I have a ragged tensor (e.g. [[1, 2, 3], [4, 5], [6]]) and I want to concatenate its rows with a special value between them, for example 0. The result would be [[1, 2, 3, 0, 4, 5, 0, 6]]. This is similar to joining strings with a separator, except that I want to do it with ragged integer tensors. I haven't found a solution that I can turn into a @tf.function. The purpose is to concatenate the tokens of a document's sentences, and the special value marks where one sentence ends and the next one starts.
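To pin down the intended semantics before reaching for TensorFlow, here is a plain-Python reference sketch that works on lists rather than tensors. It is only a specification of the expected output and is not usable inside @tf.function. The name join_with_separator is hypothetical.

```python
def join_with_separator(rows, sep):
    """Hypothetical reference: concatenate rows, inserting `sep` between consecutive rows."""
    out = []
    for i, row in enumerate(rows):
        if i > 0:
            out.append(sep)  # separator goes between rows, never after the last one
        out.extend(row)
    return out


print(join_with_separator([[1, 2, 3], [4, 5], [6]], 0))  # [1, 2, 3, 0, 4, 5, 0, 6]
```

A function like this is handy for checking a graph-mode TensorFlow implementation on small inputs.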


Solution

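  • Try using tf.concat and RaggedTensor.merge_dims: append a separator to every row except the last one, then flatten the rows into a single sequence: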

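    import tensorflow as tf
    
    ragged = tf.ragged.constant([[1, 2, 3], [4, 5], [6]])
    rows = ragged.nrows()  # number of rows (sentences)
    # Separator rows: [0] for every row except the last, which gets [] -> [[0], [0], []]
    separators = tf.concat([tf.expand_dims(tf.repeat([0], repeats=rows - 1), axis=-1),
                            tf.ragged.constant([[]], dtype=tf.int32)], axis=0)
    # Row-wise concatenation -> [[1, 2, 3, 0], [4, 5, 0], [6]]
    ragged = tf.concat([ragged, separators], axis=-1)
    # Flatten all rows into one sequence, then add back a leading dimension of size 1
    ragged = tf.expand_dims(ragged.merge_dims(0, 1), axis=0)
    print(ragged)
    # tf.Tensor([[1 2 3 0 4 5 0 6]], shape=(1, 8), dtype=int32)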
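A possibly simpler variant (a sketch, not part of the original answer) appends the separator to every row, including the last, flattens the result with flat_values, and then slices off the trailing separator. This avoids building the special empty last row. The function name join_ragged and its sep parameter are hypothetical. The sketch assumes a 2-D ragged tensor of integers and is wrapped in @tf.function, since that was the requirement in the question.

```python
import tensorflow as tf


@tf.function
def join_ragged(rt, sep=0):
    """Hypothetical helper: join the rows of a 2-D RaggedTensor with `sep` between them."""
    # Dense [nrows, 1] column filled with the separator (cast to the input's dtype).
    sep_col = tf.expand_dims(
        tf.fill(tf.shape(rt.row_lengths()), tf.cast(sep, rt.dtype)), axis=-1)
    # Make it ragged with the same row_splits dtype so both concat inputs match.
    sep_col = tf.RaggedTensor.from_tensor(sep_col, row_splits_dtype=rt.row_splits.dtype)
    # Row-wise concat, e.g. [[1, 2, 3, 0], [4, 5, 0], [6, 0]]
    with_sep = tf.concat([rt, sep_col], axis=1)
    # flat_values is the 1-D tensor of all values; drop the trailing separator.
    return with_sep.flat_values[:-1]


rt = tf.ragged.constant([[1, 2, 3], [4, 5], [6]])
joined = join_ragged(rt)
print(joined)                          # tf.Tensor([1 2 3 0 4 5 0 6], shape=(8,), dtype=int32)
print(tf.expand_dims(joined, axis=0))  # same values with a leading dimension of size 1
```

The helper returns a 1-D tensor. Wrap it in tf.expand_dims(..., axis=0) if you need the (1, 8) shape shown in the answer above. An input with zero rows yields an empty tensor, because slicing [:-1] on an empty tensor is still empty.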