
Compute the Hessian matrix (diagonal part only) with respect to a high-rank tensor


I would like to compute the first and second derivatives (the diagonal part of the Hessian) of my loss with respect to each feature map of a VGG16 conv4_3 layer's kernel, which is a 3x3x512x512 tensor. I know how to compute these derivatives with respect to a low-rank tensor, following "How to compute all second derivatives (only the diagonal of the Hessian matrix) in Tensorflow?". However, when it comes to a higher-rank tensor, I get completely lost.

# Inspecting variables under Ipython notebook
In  : Loss 
Out : <tf.Tensor 'local/total_losses:0' shape=() dtype=float32>

In  : conv4_3_kernel.get_shape() 
Out : TensorShape([Dimension(3), Dimension(3), Dimension(512), Dimension(512)])

## Compute derivatives
Grad = tf.gradients(Loss, conv4_3_kernel)
Hessian = tf.gradients(Grad, conv4_3_kernel)

In  : Grad 
Out : [<tf.Tensor 'gradients/vgg/conv4_3/Conv2D_grad/Conv2DBackpropFilter:0' shape=(3, 3, 512, 512) dtype=float32>]

In  : Hessian 
Out : [<tf.Tensor 'gradients_2/vgg/conv4_3/Conv2D_grad/Conv2DBackpropFilter:0' shape=(3, 3, 512, 512) dtype=float32>]

Please help me check my understanding. For conv4_3_kernel, the dimensions stand for [Kx, Ky, in_channels, out_channels], so Grad should hold the partial derivatives of Loss with respect to each element (pixel) of each feature map, and Hessian should hold the second derivatives.

But Hessian computes all the derivatives. How can I compute only the diagonal part? Should I use tf.diag_part()?


Solution

  • tf.gradients computes the derivative of a scalar quantity. If the quantity provided isn't a scalar, it is turned into one by summing its components, which is what's happening in your example. Your Hessian tensor therefore holds, for each kernel element, the sum of one row of the Hessian, not its diagonal entry.

    To compute the full Hessian you need n calls to tf.gradients. The example is here. If you want just the diagonal part, modify the arguments of the i-th call to tf.gradients so that it differentiates with respect to the i-th variable rather than all variables.
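
As a rough illustration of the per-element approach, here is a minimal sketch on a toy 2x2 tensor. It is not taken from the linked example. The names flat_var, parts and kernel, and the toy loss, are all made up for the illustration. It assumes graph mode (tf.compat.v1 under TensorFlow 2.x) and assumes you are free to build the kernel from unstacked scalar pieces so that each piece can be passed to tf.gradients on its own.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # tf.gradients needs graph mode under TF 2.x

# Toy stand-in for the conv kernel (hypothetical shape and values).
shape = (2, 2)
init = np.arange(1.0, 5.0, dtype=np.float32)  # [1, 2, 3, 4]

# Build the kernel from scalar pieces so each element is its own tensor.
flat_var = tf.Variable(init)
parts = tf.unstack(flat_var)                  # list of 4 scalar tensors
kernel = tf.reshape(tf.stack(parts), shape)

# Toy loss with a non-diagonal Hessian: H = diag(6 * k) + 2 * ones.
loss = tf.reduce_sum(kernel ** 3) + tf.reduce_sum(kernel) ** 2

# i-th call: differentiate the i-th gradient entry w.r.t. the i-th element only.
diag_terms = []
for p in parts:
    g = tf.gradients(loss, p)[0]   # dLoss / dp
    h = tf.gradients(g, p)[0]      # d2Loss / dp2
    diag_terms.append(h)
hess_diag = tf.reshape(tf.stack(diag_terms), shape)

# What the code in the question computes: the gradient of sum(grad),
# i.e. the row sums of the Hessian rather than its diagonal.
grad = tf.gradients(loss, kernel)[0]
summed = tf.gradients(grad, kernel)[0]

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    diag_val, summed_val = sess.run([hess_diag, summed])

print(diag_val)    # diagonal: 6 * k + 2
print(summed_val)  # row sums: 6 * k + 8
```

For this toy loss the diagonal is 6k + 2, i.e. [[8, 14], [20, 26]], while the summed version gives 6k + 8, i.e. [[14, 20], [26, 32]], so the two really are different quantities. Note that the loop adds two gradient subgraphs per element, so for a 3x3x512x512 kernel (2,359,296 elements) it is likely too slow to use as-is. You would probably restrict it to the elements you care about.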