I wrote some sample code to reproduce the real problem I am facing in my project. I am using an LSTM in TensorFlow to model some time series data. The input dimensions are (10, 100, 1), that is, 10 instances, 100 time steps, and 1 feature. The output has the same shape.
What I want to achieve after training the model is to study the influence of each input on each output at each particular time step. In other words, I would like to see which input variables affect my output the most (or which input has the largest influence on the output, perhaps the largest gradient) at each time step. Here is the code for this problem:
import numpy as np
import tensorflow as tf

tf.keras.backend.clear_session()
tf.random.set_seed(42)

# build the dataset: 10 instances, 100 time steps, 1 feature
model_input = tf.data.Dataset.from_tensor_slices(np.random.normal(size=(10, 100, 1)))
model_input = model_input.batch(10)
model_output = tf.data.Dataset.from_tensor_slices(np.random.normal(size=(10, 100, 1)))
model_output = model_output.batch(10)
my_dataset = tf.data.Dataset.zip((model_input, model_output))

# LSTM that returns a prediction at every time step
m_inputs = tf.keras.Input(shape=(None, 1))
lstm_outputs = tf.keras.layers.LSTM(32, return_sequences=True)(m_inputs)
m_outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1))(lstm_outputs)
my_model = tf.keras.Model(m_inputs, m_outputs, name="my_model")

my_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
my_loss_fn = tf.keras.losses.MeanSquaredError()
my_epochs = 3

for epoch in range(my_epochs):
    for step, (x_batch_tr, y_batch_tr) in enumerate(my_dataset):
        # open a gradient tape to record the operations run during the forward pass,
        # which enables autodifferentiation
        with tf.GradientTape() as tape:
            # run the forward pass of the layer
            logits = my_model(x_batch_tr, training=True)
            # compute the loss value for this mismatch
            loss_value = my_loss_fn(y_batch_tr, logits)
        # use the gradient tape to automatically retrieve the gradients of the
        # trainable variables with respect to the loss
        grads = tape.gradient(loss_value, my_model.trainable_weights)
        # run one step of gradient descent by updating the variables to minimize the loss
        my_optimizer.apply_gradients(zip(grads, my_model.trainable_weights))
        print(f"Step {step}, loss: {loss_value}")
print("\n\nCalculate gradient of ouptuts w.r.t inputs\n\n")
for step, (x_batch_tr, y_batch_tr) in enumerate(my_dataset):
# open a gradient tape to record the operations run during the forward pass, which enables autodifferentiation
with tf.GradientTape() as tape:
tape.watch(x_batch_tr)
# Run the forward pass of the layer
logits = my_model(x_batch_tr, training=True)
#tape.watch(logits[:, 10, :]) # this didn't help
# compute the loss value for this mismatch
loss_value = my_loss_fn(y_batch_tr, logits)
# use the gradient tape to automatically retrieve the gradients of the trainable variables with respect to the loss.
# grads = tape.gradient(logits, x_batch_tr) # This works
# print(grads.numpy().shape) # This works
grads = tape.gradient(logits[:, 10, :], x_batch_tr)
print(grads)
In other words, I would like to identify the inputs that affect my output the most (at each particular time step).
To me, grads = tape.gradient(logits, x_batch_tr) won't do the job, because it adds up the gradients from all outputs with respect to each input.
However, when I slice the output as above, the gradients are always None.
Any help is much appreciated!
You can use tf.GradientTape.batch_jacobian to get precisely that information:
grads = tape.batch_jacobian(logits, x_batch_tr)
print(grads.shape)
# (10, 100, 1, 100, 1)
Here, grads[i, t1, f1, t2, f2] gives you, for example i, the gradient of output feature f1 at time t1 with respect to input feature f2 at time t2. If, as in your case, you only have one feature, you can just say that grads[i, t1, 0, t2, 0] gives you the gradient of output time step t1 with respect to input time step t2. Conveniently, you can also aggregate different axes or slices of this result to get aggregated gradients. For example, tf.reduce_sum(grads[:, :, :, :10], axis=3) would give you the gradient of each output time step with respect to the first ten input time steps.
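For instance, here is a minimal sketch of how your second loop could use batch_jacobian; the names and the choice of output time step 10 simply mirror your code, and jac / influence_on_t10 are just illustrative helper names:

for step, (x_batch_tr, y_batch_tr) in enumerate(my_dataset):
    with tf.GradientTape() as tape:
        # the inputs are plain tensors, so they must be watched explicitly
        tape.watch(x_batch_tr)
        logits = my_model(x_batch_tr, training=True)
    # per-example Jacobian: (batch, out_time, out_feat, in_time, in_feat)
    jac = tape.batch_jacobian(logits, x_batch_tr)
    # influence of every input time step on the output at time step 10
    influence_on_t10 = jac[:, 10, 0, :, 0]  # shape (10, 100)
    print(influence_on_t10.shape)

Each row of influence_on_t10 then tells you, for one instance, how strongly each of the 100 input time steps drives the output at step 10.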
About getting None gradients in your example, I think it is because you are doing the slicing operation outside of the gradient tape context, so the gradient tracking is lost.
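If you still want to use tape.gradient for a single output time step, here is a sketch of the fix under that assumption: take the slice inside the with block so that the slicing op is recorded on the tape (logits_t10 is just an illustrative name):

with tf.GradientTape() as tape:
    tape.watch(x_batch_tr)
    logits = my_model(x_batch_tr, training=True)
    # slice inside the tape context so the slicing op is tracked
    logits_t10 = logits[:, 10, :]
grads = tape.gradient(logits_t10, x_batch_tr)
print(grads.shape)  # (10, 100, 1), no longer None

Note that tape.gradient sums over the elements of logits_t10, which is harmless here because each example's output only depends on that example's input.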