Given a TensorFlow tf.while_loop, how can I calculate the gradient of x_out with respect to all weights of the network for each time step?
import tensorflow as tf

network_input = tf.placeholder(tf.float32, [None])
steps = tf.constant(0.0)
weight_0 = tf.Variable(1.0)
# First layer, applied once before the loop
layer_1 = network_input * weight_0

def condition(steps, x):
    return steps <= 5

def loop(steps, x_in):
    # Loop body: multiply by a weight and advance the step counter
    weight_1 = tf.Variable(1.0)
    x_out = x_in * weight_1
    steps += 1
    return [steps, x_out]

_, x_final = tf.while_loop(
    condition,
    loop,
    [steps, layer_1]
)
Some notes:

tf.gradients(x, tf.trainable_variables()) crashes with AttributeError: 'WhileContext' object has no attribute 'pred'. It seems the only way to use tf.gradients within the loop is to calculate the gradient with respect to weight_1 and the current value of x_in for the current time step only, without backpropagating through time.

You can't ever call tf.gradients inside tf.while_loop in TensorFlow, based on this and this; I found this out the hard way when I was trying to implement conjugate gradient descent entirely within the TensorFlow graph.
But if I understand your model correctly, you could make your own version of an RNNCell and wrap it in tf.nn.dynamic_rnn, though the actual cell implementation will be a little complex, since you need to evaluate a stopping condition dynamically at runtime. For starters, you can take a look at TensorFlow's dynamic_rnn code here.
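To make the idea concrete, here is a rough sketch of such a cell, assuming a fixed number of steps in place of the dynamic stopping condition (the dynamic condition is the complex part I mentioned). MulCell and num_steps are names I made up for illustration; the cell simply reapplies weight_1, mirroring your loop body. Because dynamic_rnn returns the output of every step, tf.gradients can then be applied per time step outside the loop:

import tensorflow as tf

class MulCell(tf.nn.rnn_cell.RNNCell):
    # Reapplies a single shared weight each step, like the loop body above
    def __init__(self, weight):
        super(MulCell, self).__init__()
        self._weight = weight

    @property
    def state_size(self):
        return 1  # the state carries x between time steps

    @property
    def output_size(self):
        return 1  # each step also emits x, so intermediate values are kept

    def call(self, inputs, state):
        x_out = state * self._weight
        return x_out, x_out

num_steps = 6  # stand-in for the steps <= 5 condition
weight_1 = tf.Variable(1.0)
cell = MulCell(weight_1)

inputs = tf.zeros([1, num_steps, 1])  # dummy inputs; only the state is used
x_0 = tf.placeholder(tf.float32, [1, 1])
outputs, _ = tf.nn.dynamic_rnn(cell, inputs, initial_state=x_0)

# outputs[:, t, :] is x_out at step t; tf.gradients works here because
# dynamic_rnn exposes the per-step values outside the while loop.
grads_per_step = [tf.gradients(outputs[:, t, :], tf.trainable_variables())
                  for t in range(num_steps)]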
Alternatively, dynamic graphs have never been TensorFlow's strong suit, so consider using another framework like PyTorch, or you can try out eager execution and see if that helps.
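For example, here is a minimal sketch of the eager approach, assuming a recent TF 1.x release where tf.enable_eager_execution and tf.GradientTape are available; per_step_gradients is a hypothetical helper name. The tf.while_loop becomes a plain Python loop, and a persistent tape yields the gradient of each step's x_out with respect to both weights, including backpropagation through time:

import tensorflow as tf
tf.enable_eager_execution()

weight_0 = tf.Variable(1.0)
weight_1 = tf.Variable(1.0)

def per_step_gradients(network_input):
    # Record the whole forward pass; persistent=True allows one
    # gradient call per time step afterwards.
    with tf.GradientTape(persistent=True) as tape:
        x = network_input * weight_0
        steps, outs = 0, []
        while steps <= 5:  # the former loop condition, now plain Python
            x = x * weight_1
            outs.append(x)
            steps += 1
    grads = [tape.gradient(x_out, [weight_0, weight_1]) for x_out in outs]
    del tape  # release the persistent tape's resources
    return grads

print(per_step_gradients(tf.constant([2.0])))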