
Broadcast and concatenate ragged tensors


I have a ragged tensor of shape [BATCH_SIZE, TIME_STEPS, EMBEDDING_DIM]. I want to augment the last axis with data from another tensor of shape [BATCH_SIZE, AUG_DIM], so that every time step of a given example is extended with the same values.

If the tensor weren't ragged, i.e. if TIME_STEPS were the same for every example, I could simply tile the second tensor along the time axis with tf.repeat and then use tf.concat:

import tensorflow as tf


# create data
# shape: [BATCH_SIZE, TIME_STEPS, EMBEDDING_DIM]
emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
# shape: [BATCH_SIZE, 1, AUG_DIM]
aug = tf.constant([[[8]], [[9]]])

# concat
aug = tf.repeat(aug, emb.shape[1], axis=1)
emb_aug = tf.concat([emb, aug], axis=-1)

This approach doesn't work when emb is ragged, because emb.shape[1] is None: the number of time steps varies across examples.

# rag and remove padding
emb = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))

# reshape for augmentation - this doesn't work
aug = tf.repeat(aug, emb.shape[1], axis=1)

ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.
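To see where the None comes from, it can help to print the ragged tensor's static shape next to its row lengths. This minimal sketch just reuses the toy data above; it illustrates the error rather than fixing it.

```python
import tensorflow as tf

# Same toy data as above: the second example's last time step is padding.
emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
emb_r = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))

# The static shape has no size for the ragged (time) dimension...
print(emb_r.shape)          # (2, None, 3)
print(emb_r.shape[1])       # None -> this is what tf.repeat received

# ...but the per-example lengths are available as an ordinary tensor.
print(emb_r.row_lengths())  # tf.Tensor([2 1], shape=(2,), dtype=int64)
```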

The goal is to create a ragged tensor emb_aug which looks like this:

<tf.RaggedTensor [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]>

Any ideas?


Solution

  • The easiest way to do this is to convert your ragged tensor to a regular tensor with tf.RaggedTensor.to_tensor() and then apply the rest of your solution. Here I'll assume you need the tensor to stay ragged. The key is to get the row lengths of your ragged tensor (the number of time steps in each example) and use them to make your augmentation tensor ragged in the same way.

    Example:

    import tensorflow as tf
    
    
    # data
    emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
    aug = tf.constant([[[8]], [[9]]])
    
    # make embeddings ragged for testing
    emb_r = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))
    
    print(emb_r.shape)
    # (2, None, 3)
    

    Next, we use row_lengths() together with tf.sequence_mask to build a boolean mask marking which time steps actually exist in each example.

    # find the row lengths of the embeddings
    rl = emb_r.row_lengths()
    
    print(rl)
    # tf.Tensor([2 1], shape=(2,), dtype=int64)
    
    # find the biggest row length
    max_rl = tf.math.reduce_max(rl)
    
    print(max_rl)
    # tf.Tensor(2, shape=(), dtype=int64)
    
    # repeat the augmented data `max_rl` number of times
    aug_t = tf.repeat(aug, repeats=max_rl, axis=1)
    
    print(aug_t)
    # tf.Tensor(
    # [[[8]
    #   [8]]
    # 
    #  [[9]
    #   [9]]], shape=(2, 2, 1), dtype=int32)
    
    # create a mask
    msk = tf.sequence_mask(rl)
    
    print(msk)
    # tf.Tensor(
    # [[ True  True]
    #  [ True False]], shape=(2, 2), dtype=bool)
    

    From here, tf.ragged.boolean_mask keeps only the time steps where the mask is True, which makes the augmented data ragged with the same row lengths as emb_r:

    # make the augmented data a ragged tensor
    aug_r = tf.ragged.boolean_mask(aug_t, msk)
    print(aug_r)
    # <tf.RaggedTensor [[[8], [8]], [[9]]]>
    
    # concatenate!
    output = tf.concat([emb_r, aug_r], 2)
    print(output)
    # <tf.RaggedTensor [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]>
    

    You can find the list of TensorFlow ops that support ragged tensors in the TensorFlow ragged tensor documentation.
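The to_tensor() route mentioned at the start of this answer can be sketched as follows, assuming you still want a ragged result at the end: pad, apply the dense recipe from the question, then strip the padding again using the saved row lengths. This is an illustration on the same toy data, not the only way to do it.

```python
import tensorflow as tf

emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
aug = tf.constant([[[8]], [[9]]])  # shape [BATCH_SIZE, 1, AUG_DIM]
emb_r = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))

# Remember each example's true number of time steps.
rl = emb_r.row_lengths()

# Pad to a dense [BATCH_SIZE, max_len, EMBEDDING_DIM] tensor (pads with 0).
emb_dense = emb_r.to_tensor()

# The time dimension is now known, so the dense recipe from the question applies.
aug_dense = tf.repeat(aug, tf.shape(emb_dense)[1], axis=1)
emb_aug_dense = tf.concat([emb_dense, aug_dense], axis=-1)

# Drop the padded time steps again using the saved row lengths.
output = tf.RaggedTensor.from_tensor(emb_aug_dense, lengths=rl)
print(output)
# <tf.RaggedTensor [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]>
```

Because the padding is removed with lengths=rl rather than by matching a padding value, this still works when a real time step happens to be all zeros (the padded rows here become [0, 0, 0, 9], so a padding value could not be used anyway). The cost is materialising a padded dense tensor, which may matter if row lengths vary a lot.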
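Another option, a different technique from the sequence_mask/boolean_mask approach above, is to give tf.repeat a 1-D repeats vector: repeating each example's augmentation row by that example's row length produces exactly one row per real time step, which tf.RaggedTensor.from_row_lengths can then regroup. A minimal sketch on the same toy data, assuming emb_r has a single ragged dimension (ragged_rank=1):

```python
import tensorflow as tf

emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
aug = tf.constant([[[8]], [[9]]])  # shape [BATCH_SIZE, 1, AUG_DIM]
emb_r = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))

rl = emb_r.row_lengths()  # [2, 1]

# Drop the singleton time axis: [BATCH_SIZE, AUG_DIM].
aug_2d = tf.squeeze(aug, axis=1)

# With a 1-D `repeats`, row i of aug_2d is repeated rl[i] times,
# giving one row per real time step: [total_time_steps, AUG_DIM].
aug_flat = tf.repeat(aug_2d, repeats=rl, axis=0)

# Regroup the flat rows into examples with the same row lengths as emb_r.
aug_r = tf.RaggedTensor.from_row_lengths(aug_flat, row_lengths=rl)

output = tf.concat([emb_r, aug_r], axis=2)
print(output)
# <tf.RaggedTensor [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]>
```

This avoids building the padded [BATCH_SIZE, max_rl, AUG_DIM] intermediate that the masking approach needs before tf.ragged.boolean_mask trims it.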