Tags: tensorflow, tensorflow2.0, tensorflow-datasets, gradienttape

Why are GradientTape results being multiplied by batch_size in TensorFlow?


I'm trying to get the gradients of a loss function w.r.t. another tensor, but the gradients are being multiplied by the batch size of the input I feed into my model.

import tensorflow as tf
from tensorflow.keras import Sequential, layers

#Sample States and Returns
states = tf.random.uniform(shape = (100,4))
returns = tf.constant([float(i) for i in range(100)])

#Creating dataset to feed data to model
states = tf.data.Dataset.from_tensor_slices(states)
returns = tf.data.Dataset.from_tensor_slices(returns)

#zipping datasets into one
batch_size = 4
dataset = tf.data.Dataset.zip((states, returns)).batch(batch_size)

model = Sequential([layers.Dense(128, input_shape =(4,), activation = tf.nn.relu), 
                    layers.Dense(1, activation = tf.nn.tanh)])

for state_batch, returns_batch in dataset:
    with tf.GradientTape(persistent=True) as tape:
        values = model(state_batch)
        loss =  returns_batch - values  
    
    # d_loss/d_values should be -1.0, but I'm getting -1.0 * batch_size
    print(tape.gradient(loss,values))
    break
Output:
tf.Tensor(
[[-4.]
 [-4.]
 [-4.]
 [-4.]], shape=(4, 1), dtype=float32)

Expected Output:
tf.Tensor(
[[-1.]
 [-1.]
 [-1.]
 [-1.]], shape=(4, 1), dtype=float32)

From the code, you can see that loss = returns - values, so d_loss/d_values should be -1.0, but the result I'm getting is d_loss/d_values = -1.0 * batch_size. Can someone point out why this is happening? How can I get the correct gradients?

Colab link: https://colab.research.google.com/drive/1x4pyGJ5ccRVSMzDAeLzcPXRtO7cNFnJf?usp=sharing


Solution

  • The problem is in this line:

    loss = returns_batch - values  
    

    Here, returns_batch has shape (4,), but values has shape (4, 1). The subtraction broadcasts the two tensors, so loss has shape (4, 4), with loss[i, j] = returns_batch[j] - values[i, 0]. Each element of values therefore appears in four elements of loss (one full row). Because tape.gradient on a non-scalar target returns the gradient of the sum of its elements, each element of values collects -1 from each of those four entries, giving -4 = -1.0 * batch_size. You can fix it, for example, like this:

    loss = returns_batch - tf.squeeze(values, axis=1)
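The broadcast and the fix can be reproduced without the model or the dataset. This is a minimal sketch that assumes hand-made constants stand in for one batch (the numbers are illustrative, not taken from the question's run), and it also shows a second, equivalent fix using tf.expand_dims on returns_batch instead of squeezing values:

```python
import tensorflow as tf

# Illustrative stand-ins for one batch of the question's data (numbers are made up).
returns_batch = tf.constant([0.0, 1.0, 2.0, 3.0])     # shape (4,), like returns_batch
values = tf.constant([[0.1], [0.2], [0.3], [0.4]])    # shape (4, 1), like model(state_batch)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(values)  # constants are not watched automatically
    # (4,) - (4, 1) broadcasts to (4, 4): element [i, j] is returns_batch[j] - values[i, 0]
    broadcast_loss = returns_batch - values
    # Fix 1: drop the trailing axis of values, so (4,) - (4,) gives (4,)
    squeezed_loss = returns_batch - tf.squeeze(values, axis=1)
    # Fix 2: add a trailing axis to returns_batch, so (4, 1) - (4, 1) gives (4, 1)
    expanded_loss = tf.expand_dims(returns_batch, axis=1) - values

# With a non-scalar target, tape.gradient differentiates the sum of its elements.
g_broadcast = tape.gradient(broadcast_loss, values)
g_squeezed = tape.gradient(squeezed_loss, values)
g_expanded = tape.gradient(expanded_loss, values)
del tape  # release the persistent tape's resources

print(broadcast_loss.shape)            # (4, 4)
print(g_broadcast.numpy().ravel())     # [-4. -4. -4. -4.]
print(g_squeezed.numpy().ravel())      # [-1. -1. -1. -1.]
print(g_expanded.numpy().ravel())      # [-1. -1. -1. -1.]
```

Either fix works as long as both operands end up with the same shape; printing loss.shape before calling tape.gradient is a quick way to catch this kind of silent broadcast.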