Compute gradients for each time step of tf.while_loop

Given a TensorFlow tf.while_loop, how can I calculate the gradient of x_out with respect to all weights of the network for each time step?

import tensorflow as tf

network_input = tf.placeholder(tf.float32, [None])
steps = tf.constant(0.0)

weight_0 = tf.Variable(1.0)
layer_1 = network_input * weight_0

def condition(steps, x):
    return steps <= 5

def loop(steps, x_in):
    weight_1 = tf.Variable(1.0)
    x_out = x_in * weight_1
    steps += 1
    return [steps, x_out]

_, x_final = tf.while_loop(
    condition,
    loop,
    [steps, layer_1])

Some notes

  1. In my network the condition is dynamic. Different runs are going to run the while loop a different number of times.
  2. Calling tf.gradients(x, tf.trainable_variables()) crashes with AttributeError: 'WhileContext' object has no attribute 'pred'. It seems the only way to use tf.gradients within the loop is to compute the gradient with respect to weight_1 and the current value of x_in (the current time step) only, without backpropagating through time.
  3. In each time step, the network is going to output a probability distribution over actions. The gradients are then needed for a policy gradient implementation.


  • You can't call tf.gradients inside tf.while_loop in TensorFlow, based on this and this. I found this out the hard way when I was trying to build conjugate gradient descent entirely inside the TensorFlow graph.

    But if I understand your model correctly, you could write your own RNNCell and wrap it in tf.nn.dynamic_rnn. The cell implementation will be a little complex, though, since you need to evaluate a condition dynamically at runtime.

    For starters, you can take a look at TensorFlow's dynamic_rnn code here.

    Alternatively, dynamic graphs have never been TensorFlow's strong suit, so consider using another framework such as PyTorch, or try eager execution and see if that helps.
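
    To make concrete what "the gradient at each time step" means for the toy model in the question, here is a framework-free sketch in plain Python. It is not TensorFlow code and not a drop-in fix. It assumes a scalar input and carries the derivatives of x with respect to weight_0 and weight_1 forward through an ordinary Python loop using the product rule, which is roughly the kind of bookkeeping an eager or define-by-run framework would do for you. The function name per_step_gradients and its max_steps parameter are made up for illustration.

```python
# Hypothetical helper, not part of any library.
# Model from the question: x_0 = input * weight_0, then x_t = x_{t-1} * weight_1.

def per_step_gradients(network_input, weight_0, weight_1, max_steps):
    """Return one (x_t, dx_t/dweight_0, dx_t/dweight_1) tuple per loop iteration."""
    x = network_input * weight_0
    dx_dw0 = network_input  # d(input * weight_0) / d weight_0
    dx_dw1 = 0.0            # x_0 does not depend on weight_1

    history = []
    steps = 0
    # Same form as the question's condition (steps <= 5); any runtime test works here.
    while steps <= max_steps:
        # Product rule for x_new = x * weight_1. The derivative updates use the
        # old value of x, so they must run before x itself is overwritten.
        dx_dw1 = dx_dw1 * weight_1 + x
        dx_dw0 = dx_dw0 * weight_1
        x = x * weight_1
        steps += 1
        history.append((x, dx_dw0, dx_dw1))
    return history


if __name__ == "__main__":
    for t, (x, g0, g1) in enumerate(per_step_gradients(2.0, 3.0, 0.5, 5), start=1):
        print(f"step {t}: x={x}, dx/dweight_0={g0}, dx/dweight_1={g1}")
```

    Because the derivative with respect to weight_1 accumulates across iterations, each entry already includes the contribution of all earlier time steps, which is the part that tf.gradients inside the loop body cannot give you.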