Tags: python, tensorflow, tensorflow2.0, gradient-descent

What is the analytic interpretation of a TensorFlow custom gradient?


The official tf.custom_gradient documentation shows how to define a custom gradient for log(1 + exp(x)):

import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

For y = log(1 + exp(x)), the analytic derivative is dy/dx = 1 - 1 / (1 + exp(x)).
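A minimal sketch to check this numerically, assuming TensorFlow 2.x in eager mode: it copies the documentation's log1pexp and compares the gradient reported by tf.GradientTape with the closed form. The sample inputs are arbitrary, and the comparison with tf.math.sigmoid relies on the identity 1 - 1/(1 + e^x) = e^x/(1 + e^x) = sigmoid(x).

```python
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
        # dy is the incoming (upstream) gradient; multiply it by the local derivative
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

x = tf.constant([-2.0, 0.0, 3.0])  # arbitrary sample points
with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)

# For a non-scalar target the tape seeds the upstream gradient with ones,
# so this is just dy/dx evaluated element-wise.
custom = tape.gradient(y, x)
analytic = 1 - 1 / (1 + tf.exp(x))  # closed-form derivative

print(custom.numpy())
print(np.allclose(custom.numpy(), analytic.numpy()))
print(np.allclose(custom.numpy(), tf.math.sigmoid(x).numpy(), atol=1e-6))
```

With an upstream gradient of 1 the custom gradient is exactly the analytic derivative, which is why the dy factor is easy to overlook.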

However, in the code, grad returns dy * (1 - 1 / (1 + exp(x))). Reading this as dy/dx = dy * (1 - 1 / (1 + exp(x))) does not give a valid equation, and reading it as dx = dy * (1 - 1 / (1 + exp(x))) is also wrong, because then the factor would have to be the reciprocal.

So what does the grad function actually compute?


Solution

  • I finally figured it out. The dy argument is misleadingly named; it would be clearer to call it upstream_gradient or upstream_dy_dx.

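    By the chain rule, if y = x[0], x = x[n], and each function in the chain computes x[i] = f_i(x[i+1]), then

    $$\frac{dy}{dx} = \frac{dx[0]}{dx[n]} = \frac{dx[0]}{dx[1]} \cdot \frac{dx[1]}{dx[2]} \cdots \frac{dx[n-1]}{dx[n]}$$

    where dx[i]/dx[i+1] is the gradient of the current function.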
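Applied to the log1pexp example from the question, the chain rule reads as follows. This is a worked instance, writing L for whatever scalar the gradient is eventually taken of (for example a loss); L is a symbol introduced here, not one from the original code.

```latex
\frac{\partial L}{\partial x}
  = \underbrace{\frac{\partial L}{\partial y}}_{\text{dy (upstream)}}
    \cdot
    \underbrace{\left(1 - \frac{1}{1 + e^{x}}\right)}_{\partial y / \partial x \text{ (local)}}
```

This is exactly the expression dy * (1 - 1 / (1 + e)) that grad returns. When y itself is the target passed to tape.gradient, the upstream factor is 1 and grad returns the plain analytic derivative.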

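    So dy is the product of all the gradients upstream of this function, dy = dx[0]/dx[i], and grad has to return dy * dx[i]/dx[i+1], which in turn becomes the upstream gradient for the next function down the chain.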
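One way to see that dy is the upstream gradient is to record the value TensorFlow actually passes in. This sketch assumes TensorFlow 2.x eager mode; seen_upstream is a hypothetical helper list added only for the demonstration. It feeds an upstream gradient of 3 in two ways: by multiplying the output by 3 before differentiating, and by seeding tape.gradient through its output_gradients argument.

```python
import tensorflow as tf

seen_upstream = []  # hypothetical helper: records every dy that grad receives

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
        seen_upstream.append(float(dy))  # what TensorFlow passes in as "dy"
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

x = tf.constant(0.5)

# Case 1: another op follows log1pexp, z = 3 * y, so dz/dy = 3 is the upstream gradient
with tf.GradientTape() as tape:
    tape.watch(x)
    z = 3.0 * log1pexp(x)
g1 = tape.gradient(z, x)

# Case 2: seed the upstream gradient directly instead of adding an op
with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)
g2 = tape.gradient(y, x, output_gradients=tf.constant(3.0))

local = 1 - 1 / (1 + tf.exp(x))  # dy/dx of log1pexp on its own

print(seen_upstream)
print(float(g1), float(g2), float(3.0 * local))
```

In both cases grad receives dy = 3.0, and the returned gradient is upstream times current, 3 times the local derivative.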


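    So, if you forget to multiply by dy, every gradient upstream of the function is thrown away and replaced by an implicit 1, which silently gives a wrong result. This is similar in spirit to a tf.stop_gradient on the upstream path, although tf.stop_gradient itself would block the gradient to x entirely rather than keep the local factor.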
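To see what happens when dy is dropped, here is a minimal sketch assuming TensorFlow 2.x eager mode. square_ok and square_bad are hypothetical names; both compute x ** 2, but square_bad ignores its upstream argument on purpose. It uses the same three-fold composition y = x ** 8 as the demo that follows, and also shows what tf.stop_gradient returns for comparison.

```python
import tensorflow as tf

@tf.custom_gradient
def square_ok(x):
    def grad(upstream):
        return upstream * 2 * x  # chain rule: upstream times local derivative
    return x ** 2, grad

@tf.custom_gradient
def square_bad(x):
    def grad(upstream):
        return 2 * x  # deliberate bug: the upstream gradient is ignored
    return x ** 2, grad

x = tf.constant(2.0)
results = {}
for name, f in [("ok", square_ok), ("bad", square_bad)]:
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = f(f(f(x)))  # y = x ** 8, so the true dy/dx is 8 * 2 ** 7 = 1024
    results[name] = float(tape.gradient(y, x))

with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.stop_gradient(square_ok(x))
results["stop_gradient"] = tape.gradient(y, x)  # None: no gradient path back to x

print(results)
```

The buggy version ends up with only the innermost local derivative, 2 * 2 = 4, instead of 1024, whereas tf.stop_gradient produces no gradient at all.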

    Here is code that demonstrates this:

    import tensorflow as tf
    
    @tf.custom_gradient
    def foo(x):
        tf.debugging.assert_rank(x, 0)
    
        def grad(dy_dx_upstream):
            dy_dx = 2 * x
            dy_dx_downstream = dy_dx * dy_dx_upstream
            tf.print(f'x={x}\tupstream={dy_dx_upstream}\tcurrent={dy_dx}\t\tdownstream={dy_dx_downstream}')
            return dy_dx_downstream
        
        y = x ** 2
        tf.print(f'x={x}\ty={y}')
        
        return y, grad
    
    
    x = tf.constant(2.0, dtype=tf.float32)
    
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x)
        y = foo(foo(foo(x))) # y = x ** 8
    
    tf.print(f'\nfinal dy/dx={tape.gradient(y, x)}')
    

    Output

    x=2.0   y=4.0
    x=4.0   y=16.0
    x=16.0  y=256.0
    x=16.0  upstream=1.0    current=32.0        downstream=32.0
    x=4.0   upstream=32.0   current=8.0     downstream=256.0
    x=2.0   upstream=256.0  current=4.0     downstream=1024.0
    
    final dy/dx=1024.0