Breaking TensorFlow gradient calculation into two (or more) parts


Is it possible to use TensorFlow's tf.gradients() function in parts, that is, to calculate the gradient of the loss w.r.t. some intermediate tensor and the gradient of that tensor w.r.t. the weights, and then multiply them to recover the original gradient of the loss w.r.t. the weights?

For example, let W and b be weights, let x be the input of a network, and let y0 denote the labels.

Assume a forward graph such as

h=Wx+b
y=tanh(h)
loss=mse(y-y0)

We can calculate tf.gradients(loss,W) and then apply (skipping some details) optimizer.apply_gradients() to update W.

I then try to extract an intermediate tensor, by using var=tf.get_default_graph().get_tensor_by_name(...), and then calculate two gradients: g1=tf.gradients(loss,var) and g2=tf.gradients(var,W). I would then, by the chain rule, expect the dimensions of g1 and g2 to work out so that I can write g=g1*g2 in some sense, and get back tf.gradients(loss,W).

Unfortunately, this is not the case. The dimensions are incorrect. Each gradient has the shape of the tensor it is taken with respect to (g1 has the shape of var, g2 has the shape of W), so the first gradient and the second one do not line up for a product. What am I missing, and how can I do this?

Thanks.


Solution

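  • tf.gradients sums over the elements of its first argument: tf.gradients(ys, xs) returns the gradient of sum(ys) w.r.t. each x, not the full Jacobian. To avoid this summation, split the tensor into scalars and apply tf.gradients to each of them (TensorFlow 1.x graph-mode code):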

    import tensorflow as tf
    
    x = tf.ones([1, 10])
    
    w = tf.get_variable("w", initializer=tf.constant(0.5, shape=[10, 5]))
    out = tf.matmul(x, w)
    out_target = tf.constant(0., shape=[5])
    
    loss = tf.reduce_mean(tf.square(out - out_target))
    
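    # Reference: gradient of the loss w.r.t. x, computed in one step.
    grad = tf.gradients(loss, x)
    
    # Part 1: d(loss)/d(out). tf.gradients returns a list holding one [1, 5] tensor.
    part_grad_1 = tf.gradients(loss, out)
    # Part 2: Jacobian d(out)/d(x), built one output element at a time so that
    # tf.gradients cannot sum over the elements of out.
    part_grad_2 = tf.concat([tf.gradients(i, x) for i in tf.split(out, 5, axis=1)], axis=1)
    
    # Chain rule: contract the two parts over the elements of out.
    grad_by_parts = tf.matmul(part_grad_1, part_grad_2)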
    
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        print(sess.run([grad]))
        print(sess.run([grad_by_parts]))
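
The code above relies on TensorFlow 1.x APIs (tf.get_variable, tf.Session). As a rough sketch of the same idea under TensorFlow 2.x with eager execution (an assumption, since the original answer does not cover it), tf.GradientTape.jacobian can return the full Jacobian directly, so the per-scalar split is not needed. The einsum contraction and the names g1 and jac are illustrative choices, not part of the original answer.

```python
import tensorflow as tf  # assumes TensorFlow 2.x, eager execution enabled

x = tf.ones([1, 10])
w = tf.Variable(tf.constant(0.5, shape=[10, 5]))
out_target = tf.zeros([5])

# persistent=True lets us query the same tape more than once.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # x is a plain tensor, so it has to be watched explicitly
    out = tf.matmul(x, w)                               # shape [1, 5]
    loss = tf.reduce_mean(tf.square(out - out_target))  # scalar

grad = tape.gradient(loss, x)   # d(loss)/d(x), shape [1, 10]
g1 = tape.gradient(loss, out)   # d(loss)/d(out), shape [1, 5]
jac = tape.jacobian(out, x)     # d(out)/d(x), shape [1, 5, 1, 10]
del tape  # release the resources held by the persistent tape

# Chain rule: contract g1 with the Jacobian over the two axes of `out`.
grad_by_parts = tf.einsum("ab,abcd->cd", g1, jac)  # shape [1, 10]

print(grad.numpy())
print(grad_by_parts.numpy())
```

The two printed arrays should agree. To split the gradient w.r.t. the weights instead, as in the question, the same pattern should apply with w in place of x (variables are watched automatically, so the tape.watch call is only needed for x).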