Tags: python, python-3.x, tensorflow, machine-learning, eager-execution

How to reuse the inner gradient in nested gradient tapes?


I am working on a routine in TensorFlow 1.15 that evaluates several Hessian-vector products for different vectors:

def hessian_v_prod(self, v):
    with tf.GradientTape() as t1:
        with tf.GradientTape() as t2:
            # evaluate loss which uses self.variables
            loss_val = self.loss()
        # take the gradient inside t1 so that t1 records it
        grad = t2.gradient(loss_val, self.variables)
        # scalar dot product of v and the gradient
        v_hat = tf.reduce_sum(tf.multiply(v, grad))

    return t1.gradient(v_hat, self.variables)

Each time I call this function, it has to re-evaluate the inner tape block and recompute the gradient, even though grad is the same regardless of the value of v. How can I reuse the grad value across calls?

I see there is an option to create a tape as tf.GradientTape(persistent=True), which keeps the tape's resources alive after gradient() is called, but I can't work out how to incorporate this into my function.
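
For reference, here is a minimal sketch of what persistent=True changes, assuming eager execution. The variable x and the expressions y and z are illustrative only and are not part of the routine above. A persistent tape allows gradient() to be called more than once on the same recording:

```python
import tensorflow as tf

# Needed on TensorFlow 1.x; on 2.x eager execution is already the default.
tf.compat.v1.enable_eager_execution()

x = tf.Variable(3.0)  # illustrative variable

with tf.GradientTape(persistent=True) as tape:
    y = x * x  # y = x^2
    z = y * y  # z = x^4

# With a non-persistent tape, the second gradient() call would raise an error.
dy_dx = tape.gradient(y, x)
dz_dx = tape.gradient(z, x)

# Drop the reference so the tape's resources can be released.
del tape
```

On its own this does not solve the reuse problem, since later operations on grad still need to be recorded on the outer tape.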


Solution

  • I had to dig into the inner workings of GradientTape but managed to figure it out. Sharing here for anyone else who may have the same problem. Spoiler alert: it's a bit hacky!

    First of all, what actually happens when we run the following?

    with tf.GradientTape() as tape:
        loss_value = self.loss()
    tape.gradient(loss_value, self.variables)
    

    To find out, we need to look at the __enter__() and __exit__() methods, which are called at the start and end of the with block respectively.

    In tensorflow_core/python/eager/backprop.py:

    def __enter__(self):
        """Enters a context inside which operations are recorded on this tape."""
        self._push_tape()
        return self
    
    def __exit__(self, typ, value, traceback):
        """Exits the recording context, no further operations are traced."""
        if self._recording:
            self._pop_tape()
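
To make the mechanism concrete, here is a toy sketch in plain Python. ToyTape is a hypothetical stand-in, not TensorFlow code; it only mirrors the two methods quoted above. It shows that a with block does nothing more than call the push and pop hooks, so driving them by hand produces the same sequence:

```python
class ToyTape:
    """Hypothetical stand-in for a tape; not TensorFlow code."""

    def __init__(self):
        self.recording = False
        self.log = []

    def _push_tape(self):
        self.recording = True
        self.log.append("push")

    def _pop_tape(self):
        self.recording = False
        self.log.append("pop")

    def __enter__(self):
        self._push_tape()
        return self

    def __exit__(self, typ, value, traceback):
        if self.recording:
            self._pop_tape()


# Recording through a with block ...
with ToyTape() as a:
    inside_a = a.recording

# ... and the same sequence driven by hand.
b = ToyTape()
b._push_tape()
inside_b = b.recording
b._pop_tape()
```

Both objects end up with the same log and are no longer recording afterwards.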
    

    We can call these private methods ourselves to control the recording without a with block.

    # Initialize outer and inner tapes
    self.gt_outer = tf.GradientTape(persistent=True)
    self.gt_inner = tf.GradientTape(persistent=True)
    
    # Begin Recording
    self.gt_outer._push_tape()
    self.gt_inner._push_tape()
    
    # evaluate loss which uses self.variables
    loss_val = self.loss()
    
    # stop recording on inner tape
    self.gt_inner._pop_tape()
    
    # Evaluate the gradient on the inner tape
    self.gt_grad = self.gt_inner.gradient(loss_val, self.variables)
    
    # Stop recording on the outer tape
    self.gt_outer._pop_tape()
    

    Now, whenever we need to evaluate a Hessian-vector product, we can reuse the outer gradient tape and the stored gradient.

    def hessian_v_prod(self, v):
        self.gt_outer._push_tape()
        v_hat = tf.reduce_sum(tf.multiply(v, self.gt_grad))
        self.gt_outer._pop_tape()
        return self.gt_outer.gradient(v_hat, self.variables)
    

    Note that we are persisting the tapes, so every evaluation of the Hessian-vector product uses more memory. There is no way to keep only part of a tape's memory, so at certain points it becomes necessary to reset the tapes.

    # reset tapes
    self.gt_outer._tape = None
    self.gt_inner._tape = None
    

    To use them again after this, we need to re-evaluate the inner block (the loss and its gradient). It's not perfect, but it does the job and gives a significant speed-up (nearly 2x) at the cost of greater memory usage.
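
Since __enter__() and __exit__() only wrap _push_tape() and _pop_tape(), the same idea can probably be written by re-entering an existing tape object with a with block, which avoids the private calls. Below is a minimal sketch, assuming eager execution and a hypothetical toy loss 0.5 * x^T A x. The names A, x and loss_fn are illustrative and do not come from the code above:

```python
import numpy as np
import tensorflow as tf

# Needed on TensorFlow 1.x; on 2.x eager execution is already the default.
tf.compat.v1.enable_eager_execution()

# Hypothetical toy problem: loss(x) = 0.5 * x^T A x, whose Hessian is A.
A = tf.constant([[2.0, 1.0], [1.0, 3.0]])
x = tf.Variable([1.0, -1.0])


def loss_fn():
    return 0.5 * tf.reduce_sum(x * tf.linalg.matvec(A, x))


# Record the loss and the inner gradient once on a persistent outer tape.
outer = tf.GradientTape(persistent=True)
with outer:
    with tf.GradientTape() as inner:
        loss_val = loss_fn()
    # Taken while the outer tape is recording, so it stays differentiable.
    grad = inner.gradient(loss_val, x)


def hessian_v_prod(v):
    # Re-enter the same tape object so the dot product is recorded on it.
    with outer:
        v_hat = tf.reduce_sum(grad * v)
    return outer.gradient(v_hat, x)


# Two products that reuse the stored inner gradient.
hv1 = hessian_v_prod(tf.constant([1.0, 0.0])).numpy()
hv2 = hessian_v_prod(tf.constant([0.0, 1.0])).numpy()
```

For this toy loss the results should match the columns of A, which makes the sketch easy to check. The memory caveat above still applies, because every call adds operations to the persistent outer tape.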