python, tensorflow, tensorflow-probability

TensorFlow Probability - Bijector training


I have been trying to follow the example from this tutorial, but I am having trouble training any of the variables.

I wrote a small example, but I haven't been able to make that work either:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Train a shift bijector
shift = tf.Variable(initial_value=tf.convert_to_tensor([1.0], dtype=tf.float32), trainable=True, name='shift_var')
bijector = tfp.bijectors.Shift(shift=shift)

# Input
x = tf.convert_to_tensor(np.array([0]), dtype=tf.float32)
target = tf.convert_to_tensor(np.array([2]), dtype=tf.float32)

optimizer = tf.optimizers.Adam(learning_rate=0.5)
nsteps = 1

print(bijector(x).numpy(), bijector.shift)
for _ in range(nsteps):

    with tf.GradientTape() as tape:
        out = bijector(x)
        loss = tf.math.square(tf.math.abs(out - target))
        #print(out, loss)
    
        gradients = tape.gradient(loss, bijector.trainable_variables)
    
    optimizer.apply_gradients(zip(gradients, bijector.trainable_variables))
    
print(bijector(x).numpy(), bijector.shift)

For nsteps = 1, the two print statements result in the following output:

[1.] <tf.Variable 'shift_var:0' shape=(1,) dtype=float32, numpy=array([1.], dtype=float32)>
[1.] <tf.Variable 'shift_var:0' shape=(1,) dtype=float32, numpy=array([1.4999993], dtype=float32)>

It seems like the bijector still uses the original shift even though the printed value of bijector.shift has been updated.

I cannot increase nsteps, because the gradient is None after the first iteration, which raises this error:

ValueError: No gradients provided for any variable: ['shift_var:0'].

I'm using

tensorflow version 2.3.0
tensorflow-probability version 0.11.0

I also tried it on a colab notebook, so I doubt it's a version problem.


Solution

  • You found a bug. The bijector's forward function weakly caches the result->input mapping to make downstream inverse and log-determinant computations fast, but this caching also appears to interfere with the gradient. A workaround is to add a del out to the training loop, as in https://colab.research.google.com/gist/brianwa84/04249c2e9eb089c2d748d05ee2c32762/bijector-cache-bug.ipynb