
Taking gradients when using tf.function


I am puzzled by the behavior I observe in the following example:

import tensorflow as tf

@tf.function
def f(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c

def fplain(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c


a = tf.Variable([[0., 1.], [1., 0.]])

with tf.GradientTape() as tape:
    b, c = f(a)
    
print('tf.function gradient: ', tape.gradient([b], [c]))

# outputs: tf.function gradient:  [None]

with tf.GradientTape() as tape:
    b, c = fplain(a)
    
print('plain gradient: ', tape.gradient([b], [c]))

# outputs: plain gradient:  [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
#        [6., 2.]], dtype=float32)>]

The second (plain) behavior is what I would expect. How should I understand the @tf.function case?

Thank you very much in advance!

(Note that this problem is distinct from "Missing gradient when using tf.function", since here all calculations are inside the function.)


Solution

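  • GradientTape does not record the individual operations inside the tf.Graph that @tf.function generates; it treats the whole function call as a single operation. Roughly, f is applied to a, and the tape records how the outputs of f depend on its input a (the only watched variable, as tape.watched_variables() shows). Because c is just another output of that single operation, not an input to the computation of b, the tape sees no path from c to b and returns None.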

    In the second case, no graph is generated and the operations run eagerly, so the tape records each of them, including the ones that compute b from c. That is why the gradient comes out as expected.

    A good practice is to wrap the most computationally expensive function (often a training loop) in @tf.function and to open the tape inside it, so that the intermediate operations are recorded. In your case, it would be something like:

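    @tf.function
    def f(a):
        # The tape lives inside the traced function, so it sees the ops
        # that produce c and the ops that consume it.
        with tf.GradientTape() as tape:
            c = a * 2
            b = tf.reduce_sum(c ** 2 + 2 * c)
        grads = tape.gradient([b], [c])
        # tf.print runs on every call; Python print would only run once, at trace time.
        tf.print('tf.function gradient: ', grads)
        return grads

    grads = f(a)  # a is the tf.Variable defined in the question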
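To see the "single operation" behavior described above, here is a minimal sketch, assuming TensorFlow 2.x in eager mode. It keeps the tape outside the tf.function, exactly as in the question, and asks the same (persistent) tape for two gradients: with respect to the input a, which flows through the recorded function call, and with respect to c, which is only another output of that call. The names grad_a, grad_c and watched are illustrative.

```python
import tensorflow as tf

@tf.function
def f(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c

a = tf.Variable([[0., 1.], [1., 0.]])

# persistent=True lets us ask the same tape for several gradients.
with tf.GradientTape(persistent=True) as tape:
    b, c = f(a)

# The call to f is recorded as one op with input a and outputs (b, c),
# so there is a recorded path from b back to a ...
grad_a = tape.gradient(b, a)
# ... but not from b back to c, which is just a sibling output of that op.
grad_c = tape.gradient(b, c)
watched = [v.name for v in tape.watched_variables()]
del tape  # release the resources held by the persistent tape

print('d b / d a:', grad_a)
print('d b / d c:', grad_c)
print('watched variables:', watched)
```

Here grad_a should equal 2 * (2c + 2), i.e. [[4, 12], [12, 4]] for this a, while grad_c is None, matching what the question observes.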
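If you would rather keep the tape outside the compiled function, another option (not part of the answer above, just a sketch under the same TensorFlow 2.x assumption) is to move the tf.function boundary so that c becomes an input of the compiled call instead of an output. The tape then records the call as an op that consumes c, and the gradient with respect to c is available. The function name loss_from_c is hypothetical.

```python
import tensorflow as tf

# Hypothetical split of the original f: only the part after c is compiled.
@tf.function
def loss_from_c(c):
    return tf.reduce_sum(c ** 2 + 2 * c)

a = tf.Variable([[0., 1.], [1., 0.]])

with tf.GradientTape() as tape:
    c = a * 2           # computed eagerly, so it is recorded on the tape
    b = loss_from_c(c)  # recorded as one op whose input is c

grad_c = tape.gradient(b, c)
print('d b / d c:', grad_c)
```

This should reproduce the plain (eager) result from the question, [[2, 6], [6, 2]], because d b / d c = 2c + 2 and c = 2a. The trade-off is that the multiplication producing c now runs eagerly rather than inside the graph.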