TensorFlow 2.0 AutoGraph indirect modification (hidden state) works when it shouldn't


The AutoGraph docs say that indirect modification (hidden state) should not work, meaning that the changes would be invisible. (What does an "invisible" change mean, anyway?)

But this code computes the gradient correctly:

import tensorflow as tf


class C:
    def __init__(self):
        self.x = tf.Variable(2.0)

    @tf.function
    def change(self):
        self.x.assign_add(2.0)

    @tf.function
    def func(self):
        self.change()
        return self.x * self.x


c = C()
with tf.GradientTape() as tape:
    y = c.func()
print(tape.gradient(y, c.x)) # --> tf.Tensor(8.0, shape=(), dtype=float32)

Am I missing something here?

Thanks


Solution

  • The docs are missing a detail and should be clarified: "invisible" means the change is not detected by AutoGraph's analyzer. AutoGraph analyzes one function at a time, so modifications made in another function are not visible to the analyzer.

    However, this caveat does not apply to ops with side effects, such as modifications to TF Variables. Those are still wired correctly in the graph, so your code should work correctly.

    The limitation only applies to some changes made to pure Python objects (lists, dicts, etc.), and it is only a problem when using control flow.

    For example, here's a modification of your code that wouldn't work:

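    import tensorflow as tf


    class C:
        def __init__(self):
            self.x = None

        def reset(self):
            self.x = tf.constant(10)

        def change(self):
            # Rebinds a Python attribute; AutoGraph does not see this
            # when it analyzes func.
            self.x += 1

        @tf.function
        def func(self):
            self.reset()
            for i in tf.range(3):
                # Becomes a tf.while_loop; self.x is not tracked as a loop variable.
                self.change()
            return self.x * self.x


    c = C()
    print(c.func())  # raises InaccessibleTensorError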
    

    The error message is rather obscure, but it's the same error that gets raised if you try to access the result of an op created inside the body of a tf.while_loop without using loop_vars:

        <ipython-input-18-23f1641cfa01>:20 func  *
            return self.x * self.x
    
        ... more internal frames ...
    
        InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(),
    dtype=int32)' cannot be accessed here: it is defined in another function or
    code block. Use return values, explicit Python locals or TensorFlow
    collections to access it. Defined in: FuncGraph(name=while_body_685,
    id=5029696157776); accessed from: FuncGraph(name=func, id=5029690557264).