python, tensorflow, gradient

How does tf.GradientTape record operations inside the with statement?


I don't understand how tf.GradientTape records operations such as y = x**2 that are executed inside the "with" statement, as in the following code.

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
  y = x**2

What Python syntax can be used to achieve this behavior?


Solution

  • EDIT:

    According to the GradientTape source code on GitHub, at line 897:

    @tf_contextlib.contextmanager
    def _ensure_recording(self):
      """Ensures that this tape is recording."""
      if not self._recording:
        try:
          self._push_tape()
          yield
        finally:
          self._pop_tape()
      else:
        yield
    

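    In case you are not familiar with it, contextmanager turns a generator function into a context manager: the code before yield runs when a with block is entered, and the code after it (here, the finally clause) runs when the block exits. So if the tape is not already recording, _ensure_recording pushes it on entry, which starts the recording, and pops it on exit. GradientTape's own __enter__ and __exit__ methods, which the with statement calls, push and pop the tape in the same way.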
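    The entry and exit order described above can be seen in a small standalone sketch. It uses the standard library's contextlib.contextmanager as a stand-in for tf_contextlib.contextmanager (assumed here to behave the same way); the names recording and events are made up for illustration.

```python
import contextlib

events = []  # records the order in which things happen


@contextlib.contextmanager
def recording(name):
    # Runs when the `with` block is entered (the role _push_tape plays above).
    events.append(f"push {name}")
    try:
        yield name  # the body of the `with` block runs here
    finally:
        # Runs when the block exits, even if the body raised
        # (the role _pop_tape plays above).
        events.append(f"pop {name}")


with recording("tape") as t:
    events.append(f"inside, using {t}")

print(events)  # ['push tape', 'inside, using tape', 'pop tape']
```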

    self._pop_tape() is defined at line 891:

    def _pop_tape(self):
      if not self._recording:
        raise ValueError("Tape is not recording.")
      tape.pop_tape(self._tape)
      self._recording = False
    

    self._push_tape() is defined at line 878:

    def _push_tape(self):
      """Pushes a new tape onto the tape stack."""
      if self._recording:
        raise ValueError("Tape is still recording, This can happen if you try to "
                         "re-enter an already-active tape.")
      if self._tape is None:
        self._tape = tape.push_new_tape(
            persistent=self._persistent,
            watch_accessed_variables=self._watch_accessed_variables)
      else:
        tape.push_tape(self._tape)
      self._recording = True
    

    Notice that _push_tape calls tape.push_new_tape, which is defined in the tape module's source code at line 43:

    def push_new_tape(persistent=False, watch_accessed_variables=True):
      """Pushes a new tape onto the tape stack."""
      tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
      return Tape(tape)
    

    The Tape class that it returns is defined just above, at line 31:

    class Tape(object):
      """Represents a gradient propagation trace."""
    
      __slots__ = ["_tape"]
    
      def __init__(self, tape):
        self._tape = tape
    
      def watched_variables(self):
        return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)
    

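    I also tried to trace pywrap_tfe.TFE_Py_TapeSetNew, but could not find its definition in that file. pywrap_tfe is a binding to TensorFlow's C++ runtime, so the definition is most likely on the C++ side rather than in the Python sources.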

    Original Answer:

    The documentation of GradientTape states:

    By default GradientTape will automatically watch any trainable variables that are accessed inside the context. If you want fine grained control over which variables are watched you can disable automatic tracking by passing watch_accessed_variables=False to the tape constructor

    With the following code:

    import tensorflow as tf

    x = tf.Variable(2.0)
    w = tf.Variable(5.0)
    with tf.GradientTape(
        watch_accessed_variables=False, persistent=True) as tape:
      tape.watch(x)
      y = x ** 2  # Gradients will be available for `x`.
      z = w ** 3  # No gradients will be available as `w` isn't being watched.

    dy_dx = tape.gradient(y, x)
    print(dy_dx)  # tf.Tensor(4.0, shape=(), dtype=float32)

    # No gradients will be available as `w` isn't being watched.
    dz_dw = tape.gradient(z, w)
    print(dz_dw)  # None