
TensorFlow 2.0 `tf.multiply` method gives unexpected results


I am a bit confused about the output shape of this operation:

>>> eps = tf.random.uniform((3))
>>> images = tf.random.normal((3, 28, 28, 1))
>>> output = eps * images
>>> output.get_shape()
(3, 28, 28, 3)

I want to multiply each scalar in eps with the corresponding image of shape (28, 28, 1) in images, so that the output has shape (3, 28, 28, 1).

Something like this:

>>> output = []
>>> output.append(eps[0] * images[0])
>>> output.append(eps[1] * images[1])
>>> output.append(eps[2] * images[2])
>>> output = tf.convert_to_tensor(output)
>>> output.get_shape()
(3, 28, 28, 1)

Please help.


Solution

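  • That is due to broadcasting (described in the NumPy documentation on broadcasting; it works the same way in TensorFlow). Shapes are aligned starting from the last dimension, so eps with shape (3,) is matched against the trailing dimension of images, which has size 1 and is stretched to 3, giving (3, 28, 28, 3). If you want to match the single dimension of eps to the first dimension of images, you need to add extra singleton dimensions to eps so that broadcasting works as you expect: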

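    import tensorflow as tf

    images = tf.random.normal((3, 28, 28, 1))
    # Note (3,) with a trailing comma: in Python, (3) is just the integer 3
    eps = tf.random.uniform((3,))
    # Add dimensions for broadcasting: (3,) -> (3, 1, 1, 1)
    eps = tf.reshape(eps, [-1, 1, 1, 1])
    output = eps * images
    print(output.get_shape())
    # (3, 28, 28, 1)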
    

    Alternatively, you can directly create eps with that shape:

    eps = tf.random.uniform((3, 1, 1, 1))
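
As a quick check of the explanation above, here is a minimal sketch, assuming TensorFlow 2.x. The names naive_shape, eps_col, scaled, and expected are illustrative and not from the original post. It uses tf.broadcast_static_shape to show which shape the unreshaped multiplication produces, then uses tf.newaxis indexing as another way to add the singleton dimensions, and compares the result with the per-image loop from the question:

```python
import tensorflow as tf

images = tf.random.normal((3, 28, 28, 1))
eps = tf.random.uniform((3,))  # one scale factor per image

# Shapes are compared from the last axis backwards, so (3,) lines up with
# the trailing axis of (3, 28, 28, 1), not the leading one.
naive_shape = tf.broadcast_static_shape(eps.shape, images.shape)

# Indexing with tf.newaxis appends singleton axes: (3,) -> (3, 1, 1, 1).
eps_col = eps[:, tf.newaxis, tf.newaxis, tf.newaxis]
scaled = eps_col * images

# Reference result built one image at a time, as in the question.
expected = tf.stack([eps[i] * images[i] for i in range(3)])

print(naive_shape)   # (3, 28, 28, 3)
print(scaled.shape)  # (3, 28, 28, 1)
```

Indexing with tf.newaxis and tf.reshape(eps, [-1, 1, 1, 1]) should give the same result here, so the choice between them is mostly a matter of taste.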