Tags: python, tensorflow, lstm, gradient-descent

Calculating the derivatives of the output with respect to the input for a given time step in an LSTM (TensorFlow 2.0)


I wrote some sample code to reproduce the real problem I am facing in my project. I am using an LSTM in TensorFlow to model some time series data. The input has shape (10, 100, 1), that is, 10 instances, 100 time steps, and 1 feature. The output has the same shape.

What I want to achieve after training the model is to study the influence of each input on each output at each particular time step. In other words, I would like to see which input time steps affect my output the most (for example, which ones have the largest gradient) at each output time step. Here is the code for this problem:

import numpy as np
import tensorflow as tf

tf.keras.backend.clear_session()
tf.random.set_seed(42)

model_input = tf.data.Dataset.from_tensor_slices(np.random.normal(size=(10, 100, 1)))
model_input = model_input.batch(10)
model_output = tf.data.Dataset.from_tensor_slices(np.random.normal(size=(10, 100, 1)))
model_output = model_output.batch(10)

my_dataset = tf.data.Dataset.zip((model_input, model_output))

m_inputs = tf.keras.Input(shape=(None, 1))

lstm_outputs = tf.keras.layers.LSTM(32, return_sequences=True)(m_inputs)
m_outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1))(lstm_outputs)

my_model = tf.keras.Model(m_inputs, m_outputs, name="my_model")

my_optimizer=tf.keras.optimizers.Adam(learning_rate=0.001)
my_loss_fn = tf.keras.losses.MeanSquaredError()

my_epochs = 3

for epoch in range(my_epochs):

    for step, (x_batch_tr, y_batch_tr) in enumerate(my_dataset):
        # open a gradient tape to record the operations run during the forward pass, which enables autodifferentiation
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer
            logits = my_model(x_batch_tr, training=True)

            # compute the loss value for this minibatch
            loss_value = my_loss_fn(y_batch_tr, logits)

        # use the gradient tape to automatically retrieve the gradients of the loss with respect to the trainable variables.
        grads = tape.gradient(loss_value, my_model.trainable_weights)

        # Run one step of gradient descent by updating the value of the variables to minimize the loss.
        my_optimizer.apply_gradients(zip(grads, my_model.trainable_weights))

        print(f"Step {step}, loss: {loss_value}")


print("\n\nCalculate gradient of outputs w.r.t. inputs\n\n")

for step, (x_batch_tr, y_batch_tr) in enumerate(my_dataset):
    # open a gradient tape to record the operations run during the forward pass, which enables autodifferentiation
    with tf.GradientTape() as tape:

        tape.watch(x_batch_tr)

        # Run the forward pass of the layer
        logits = my_model(x_batch_tr, training=True)
        #tape.watch(logits[:, 10, :])   # this didn't help
        # compute the loss value for this minibatch
        loss_value = my_loss_fn(y_batch_tr, logits)

    # use the gradient tape to retrieve the gradients of the model outputs with respect to the inputs.
#     grads = tape.gradient(logits, x_batch_tr)   # This works
#     print(grads.numpy().shape)                  # This works
    grads = tape.gradient(logits[:, 10, :], x_batch_tr)
    print(grads)

In other words, I would like to pay attention to the inputs that affect my output the most (at each particular time step).

To me, grads = tape.gradient(logits, x_batch_tr) won't do the job because it sums the gradients from all outputs w.r.t. each input.

However, the gradients are always None.

Any help is much appreciated!


Solution

  • You can use tf.GradientTape.batch_jacobian to get precisely that information:

    grads = tape.batch_jacobian(logits, x_batch_tr)
    print(grads.shape)
    # (10, 100, 1, 100, 1)
    

    Here, grads[i, t1, f1, t2, f2] gives you, for example i, the gradient of output feature f1 at time t1 with respect to input feature f2 at time t2. If, as in your case, you only have one feature, grads[i, t1, 0, t2, 0] is simply the gradient of the output at time t1 with respect to the input at time t2. Conveniently, you can also aggregate different axes or slices of this result to get aggregated gradients. For example, tf.reduce_sum(grads[:, :, :, :10], axis=3) would give you, for each output time step, the sum of its gradients with respect to the first ten input time steps.

    As for the None gradients in your example, I think it is because you are doing the slicing operation (logits[:, 10, :]) outside of the gradient tape context, so the slice is never recorded on the tape and the gradient tracking is lost.
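To make both points concrete, here is a minimal sketch rather than a definitive implementation. It uses a small untrained stand-in for the model in the question; the names (`model`, `x`, `grad_t10`, `jac`, `most_influential`) and the sizes (4 examples, 20 time steps, 8 LSTM units) are assumptions chosen only to keep it fast. It first slices inside the tape, then computes the full batch Jacobian, and finally picks the most influential input time step for each output time step. The Jacobian call here passes `experimental_use_pfor=False` with a persistent tape, which is a conservative choice on my part to avoid vectorization issues on some setups; the default call shown in the answer should behave the same way.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(42)
np.random.seed(42)

# Small, untrained stand-in for the model in the question (sizes are assumptions).
inputs = tf.keras.Input(shape=(None, 1))
hidden = tf.keras.layers.LSTM(8, return_sequences=True)(inputs)
outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1))(hidden)
model = tf.keras.Model(inputs, outputs)

# float32 input so the watched tensor and the model output share a dtype.
x = tf.constant(np.random.normal(size=(4, 20, 1)), dtype=tf.float32)

# Option 1: slice INSIDE the tape so the slicing op is recorded.
with tf.GradientTape() as tape:
    tape.watch(x)
    y = model(x, training=False)
    y_t10 = y[:, 10, :]  # output at time step 10 only
grad_t10 = tape.gradient(y_t10, x)  # shape (4, 20, 1), not None

# Option 2: the full per-example Jacobian of all outputs w.r.t. all inputs.
# persistent=True is required in eager mode when experimental_use_pfor=False.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = model(x, training=False)
jac = tape.batch_jacobian(y, x, experimental_use_pfor=False)  # (4, 20, 1, 20, 1)
del tape  # release the resources held by the persistent tape

# Sensitivity of each output step to each input step: (batch, t_out, t_in).
sensitivity = tf.abs(jac[:, :, 0, :, 0])

# For every example and output step, the input step with the largest |gradient|.
most_influential = tf.argmax(sensitivity, axis=-1)  # shape (4, 20)

print(grad_t10.shape, jac.shape, most_influential.shape)
```

Row `t1 = 10` of the Jacobian, `jac[:, 10, 0, :, :]`, matches `grad_t10`, so the single-step slice is just one row of the full result. Because the LSTM is unidirectional, the entries with `t2 > t1` are zero: an output cannot depend on future inputs.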