TensorFlow custom gradients


I have a custom gradient calculation function which doubles the incoming gradients.

import tensorflow as tf

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
    return grad * 2.0

c = tf.constant(3.)

# Gradient computed with the default Square gradient.
s1 = tf.square(c)
grad1 = tf.gradients(s1, c)[0]

# Gradient computed with the override in effect.
g = tf.get_default_graph()
with g.gradient_override_map({"Square": "CustomSquare"}):
    s2 = tf.square(c)
    grad2 = tf.gradients(s2, c)[0]

with tf.Session() as sess:
    print(sess.run([c, s1, grad1]))
    print(sess.run([c, s2, grad2]))

The results I get are surprising:

[3.0, 9.0, 6.0]
[3.0, 9.0, 2.0]

I was expecting the second result to be [3.0, 9.0, 12.0]. What am I missing?

Thanks.


Solution

  • In short, the correct version of _custom_square_grad should be:

    @tf.RegisterGradient("CustomSquare")
    def _custom_square_grad(op, grad):
        x = op.inputs[0]
        # Compute the default Square gradient (grad * 2x), then double it.
        return 2.0 * (grad * 2.0 * x)

    To understand the code, you need to know how gradient computation works. A function registered with tf.RegisterGradient is supposed to back-propagate gradients from an op's outputs to its inputs. For tf.square, the default gradient function looks like this:

    # Given y = tf.square(x) => y' = 2x
    grad_x = grad_y * 2.0 * x
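    Plugging in the numbers from the question: at the output of s1 the incoming gradient grad_y is 1.0 and x is 3.0, so the default rule gives grad_x = 1.0 * 2.0 * 3.0 = 6.0, which is exactly the grad1 value printed above.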

    Your original function discards the default computation entirely and returns grad * 2.0; since the incoming gradient at the output is 1.0, that evaluates to the constant 2.0 you observed, regardless of x. Because you want to double the gradient, you have to keep the default computation and scale it: grad_x = 2.0 * (grad_y * 2.0 * x).
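
    Dropping the corrected function into the program from the question (same TensorFlow 1.x APIs) should now print the result you expected:

    import tensorflow as tf

    @tf.RegisterGradient("CustomSquare")
    def _custom_square_grad(op, grad):
        x = op.inputs[0]
        # Default Square gradient, scaled by 2.
        return 2.0 * (grad * 2.0 * x)

    c = tf.constant(3.)
    g = tf.get_default_graph()
    with g.gradient_override_map({"Square": "CustomSquare"}):
        s2 = tf.square(c)
        grad2 = tf.gradients(s2, c)[0]

    with tf.Session() as sess:
        print(sess.run([c, s2, grad2]))  # expected: [3.0, 9.0, 12.0]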