Broadcasting with ragged tensor


Define x as:

>>> import tensorflow as tf
>>> x = tf.constant([1, 2, 3])

Why does this normal tensor multiplication work fine with broadcasting:

>>> tf.constant([[1, 2, 3], [4, 5, 6]]) * tf.expand_dims(x, axis=0)
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 1,  4,  9],
       [ 4, 10, 18]], dtype=int32)>

while this one with a ragged tensor does not?

>>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]]) * tf.expand_dims(x, axis=0)
*** tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected 'tf.Tensor(False, shape=(), dtype=bool)' to be true. Summarized data: b'Unable to broadcast: dimension size mismatch in dimension'
1
b'lengths='
3
b'dim_size='
3, 3

How can I get a 1-D tensor to broadcast over a 2-D ragged tensor? (I am using TensorFlow 2.1.)


Solution

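  • The problem goes away if you pass ragged_rank=0 to tf.ragged.constant, as shown below. With ragged_rank=0 the value has no ragged dimensions, so tf.ragged.constant returns an ordinary dense tf.Tensor (note the tf.Tensor in the printed output), and the normal broadcasting rules apply: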

    tf.ragged.constant([[1, 2, 3], [4, 5, 6]], ragged_rank=0) * tf.expand_dims(x, axis=0)
    

    The complete working code is:

    # Colab-only magic that selects TF 2.x; omit it when running as a plain Python script
    %tensorflow_version 2.x
    
    import tensorflow as tf
    x = tf.constant([1, 2, 3])
    
    print(tf.ragged.constant([[1, 2, 3], [4, 5, 6]], ragged_rank=0) * tf.expand_dims(x, axis=0))
    

    The output of the above code is:

    tf.Tensor(
    [[ 1  4  9]
     [ 4 10 18]], shape=(2, 3), dtype=int32)
    

    One more correction.

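    Broadcasting is, by definition, the process of **making** tensors with different shapes have compatible shapes for elementwise operations. When one operand has fewer dimensions, TensorFlow treats it as if size-1 dimensions were prepended (so a tensor of shape [3] behaves like one of shape [1, 3]). There is therefore no need to call tf.expand_dims explicitly; TensorFlow takes care of it.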

    So the code below works as well and demonstrates this property of broadcasting:

    # Colab-only magic that selects TF 2.x; omit it when running as a plain Python script
    %tensorflow_version 2.x
    
    import tensorflow as tf
    x = tf.constant([1, 2, 3])
    
    print(tf.ragged.constant([[1, 2, 3], [4, 5, 6]], ragged_rank=0) * x)
    

    The output of the above code is:

    tf.Tensor(
    [[ 1  4  9]
     [ 4 10 18]], shape=(2, 3), dtype=int32)
    

    For more information, please refer to this link.

    Hope this helps. Happy Learning!
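A minimal sketch, assuming TensorFlow 2.x with eager execution (as in the question), that checks the points above: ragged_rank=0 gives back a plain tf.Tensor, the product is identical with or without tf.expand_dims, and tf.broadcast_static_shape reports the combined shape [2, 3]. The last part goes beyond the original answer and is an assumption: if you already hold a tf.RaggedTensor whose rows all have the same length, densifying it with to_tensor(), multiplying, and converting back with tf.RaggedTensor.from_tensor is one way to get the same result. Variable names such as dense, rt and back are illustrative only.

```python
import tensorflow as tf

x = tf.constant([1, 2, 3])

# With ragged_rank=0 there are no ragged dimensions, so tf.ragged.constant
# returns an ordinary dense tf.Tensor rather than a tf.RaggedTensor.
dense = tf.ragged.constant([[1, 2, 3], [4, 5, 6]], ragged_rank=0)
print(type(dense), dense.shape)

# Dense broadcasting: x (shape [3]) is treated as shape [1, 3] and then
# stretched to [2, 3], so the explicit expand_dims is optional.
with_expand = dense * tf.expand_dims(x, axis=0)
without_expand = dense * x
print(tf.reduce_all(tf.equal(with_expand, without_expand)).numpy())

# The static shape that broadcasting [2, 3] against [3] produces.
print(tf.broadcast_static_shape(tf.TensorShape([2, 3]), tf.TensorShape([3])))

# Assumption: starting from an existing RaggedTensor whose rows all have
# length 3, densify it, multiply, and convert the product back.
rt = tf.ragged.constant([[1, 2, 3], [4, 5, 6]])  # a genuine tf.RaggedTensor
product = rt.to_tensor() * x                     # dense [2, 3] product
back = tf.RaggedTensor.from_tensor(product)      # ragged again, same values
print(back.to_list())
```

Keep in mind that to_tensor() pads shorter rows with zeros, so for rows of different lengths this round trip would need extra care before converting back.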