The official tf.custom_gradient documentation shows how to define a custom gradient for log(1 + exp(x)):
```python
@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
```
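As a quick sanity check (my own addition, not part of the docs example), the custom gradient can be compared against TensorFlow's built-in sigmoid, since 1 - 1/(1 + exp(x)) simplifies to sigmoid(x):

```python
import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
        return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)

g = tape.gradient(y, x)
print(g.numpy())              # gradient from the custom grad function
print(tf.sigmoid(x).numpy())  # analytic derivative, should match
```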
When y = log(1 + exp(x)), the analytic derivative is dy/dx = 1 - 1 / (1 + exp(x)).

However, in the code, grad returns dy * (1 - 1 / (1 + exp(x))).

dy/dx = dy * (1 - 1 / (1 + exp(x))) is not a valid equation, while dx = dy * (1 - 1 / (1 + exp(x))) is wrong, since that would have to be the reciprocal.

What does the grad function actually compute?
I finally figured it out. The argument dy would be better named upstream_gradient or upstream_dy_dx.

By the chain rule we know that

dy/dx[i+1] = dy/dx[i] * dx[i]/dx[i+1]

where dx[i]/dx[i+1] is the gradient of the current function. So the dy passed into grad is the product of all the gradients computed upstream of this function, and grad must multiply it by the local gradient before returning. If you forget to multiply by dy, the upstream gradient is silently discarded, which is effectively the same as applying tf.stop_gradient.
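A small sketch (my own, not from the original answer) of what goes wrong when the upstream dy is dropped: any scaling applied downstream of the custom op vanishes from the gradient, much like cutting the chain with tf.stop_gradient would.

```python
import tensorflow as tf

@tf.custom_gradient
def square_forgot_dy(x):
    def grad(dy):
        return 2 * x  # BUG: upstream gradient dy is ignored
    return x ** 2, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = 5.0 * square_forgot_dy(x)  # correct d/dx is 5 * 2x = 30

g = tape.gradient(y, x)
print(g.numpy())  # 6.0 -- the upstream factor of 5 was lost
```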
Here is code that demonstrates this. Full notebook here
```python
@tf.custom_gradient
def foo(x):
    tf.debugging.assert_rank(x, 0)

    def grad(dy_dx_upstream):
        dy_dx = 2 * x
        dy_dx_downstream = dy_dx * dy_dx_upstream
        tf.print(f'x={x}\tupstream={dy_dx_upstream}\tcurrent={dy_dx}\t\tdownstream={dy_dx_downstream}')
        return dy_dx_downstream

    y = x ** 2
    tf.print(f'x={x}\ty={y}')
    return y, grad


x = tf.constant(2.0, dtype=tf.float32)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = foo(foo(foo(x)))  # y = x ** 8

tf.print(f'\nfinal dy/dx={tape.gradient(y, x)}')
```
Output:

```
x=2.0  y=4.0
x=4.0  y=16.0
x=16.0 y=256.0
x=16.0 upstream=1.0   current=32.0 downstream=32.0
x=4.0  upstream=32.0  current=8.0  downstream=256.0
x=2.0  upstream=256.0 current=4.0  downstream=1024.0
final dy/dx=1024.0
```
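The final number checks out analytically: y = x ** 8, so dy/dx = 8 * x ** 7, and at x = 2 that is 8 * 128 = 1024. The same value falls out of a plain GradientTape with no custom gradient at all:

```python
import tensorflow as tf

x = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x ** 8

g = tape.gradient(y, x)
print(g.numpy())  # 1024.0 == 8 * 2 ** 7
```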