python tensorflow deep-learning hessian-matrix

Use tf.gradients or tf.hessians on flattened parameter tensor


Let's say I want to compute the Hessian of a scalar-valued function with respect to some parameters W (e.g. the weights and biases of a feed-forward neural network). If you consider the following code, implementing a two-dimensional linear model trained to minimize an MSE loss:

import numpy as np
import tensorflow as tf

x = tf.placeholder(dtype=tf.float32, shape=[None, 2])  #inputs
t = tf.placeholder(dtype=tf.float32, shape=[None, 2])  #labels, same shape as preds
W = tf.Variable(np.eye(2), dtype=tf.float32)  #weights (a trainable variable, not a placeholder)

preds = tf.matmul(x, W)  #linear model
loss = tf.reduce_mean(tf.square(preds - t))  #scalar mse loss (mean over all entries)

params = tf.trainable_variables() 
hessian = tf.hessians(loss, params)

you'd expect session.run(hessian, feed_dict={}) to return a 2x2 matrix (equal to W). It turns out that because params holds a 2x2 tensor, the output is instead a tensor with shape [2, 2, 2, 2]. While I can easily reshape the tensor to obtain the matrix I want, this operation might become extremely cumbersome when params is a list of tensors of varying sizes (i.e. when the model is a deep neural network, for instance).
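
The reshape described above can be automated, because the array returned for a parameter of shape s has shape s + s. Below is a minimal NumPy sketch, assuming blocks is the list of arrays that session.run returns for the output of tf.hessians. The helper name flatten_hessian_blocks and the stand-in shapes are hypothetical, chosen only for illustration:

```python
import numpy as np

def flatten_hessian_blocks(blocks):
    """Reshape each block of shape p.shape + p.shape into an (n, n) matrix."""
    matrices = []
    for block in blocks:
        half = block.ndim // 2  # the leading axes index one copy of the parameter
        n = int(np.prod(block.shape[:half]))
        matrices.append(block.reshape(n, n))
    return matrices

# Stand-ins for the arrays returned by session.run: a (3, 2) weight and a (2,) bias.
blocks = [np.zeros((3, 2, 3, 2)), np.zeros((2, 2))]
matrices = flatten_hessian_blocks(blocks)
print([m.shape for m in matrices])  # [(6, 6), (2, 2)]
```

The row-major reshape orders rows and columns the same way tf.reshape(p, [-1]) orders the entries of p, so each matrix is indexed by the flattened parameter.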

It seems that there are two ways around this:

  • Flatten params to be a 1D tensor called flat_params:

    flat_params = tf.concat([tf.reshape(p, [-1]) for p in params], axis=0)
    

    so that tf.hessians(loss, flat_params) naturally returns a 2x2 matrix. However, as noted in "Why does Tensorflow Reshape tf.reshape() break the flow of gradients?" for tf.gradients (and the same holds for tf.hessians), TensorFlow cannot see a symbolic link in the graph between loss and flat_params: the loss is built from params, and flat_params is a new tensor derived from them, so the loss does not depend on it. As a result, tf.hessians(loss, flat_params) will raise an error because the gradients are seen as None.

  • In https://afqueiruga.github.io/tensorflow/2017/12/28/hessian-mnist.html, the author of the code goes the other way: they first create the flat parameter and then reshape its parts into self.params. This trick does work and gets you the Hessian with its expected shape (2x2 matrix). However, it seems to me that this will be cumbersome to use when you have a complex model, and impossible to apply if you create your model via built-in functions (like tf.layers.dense, ...).
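
The bookkeeping behind the second approach amounts to slicing one flat vector into consecutive pieces and reshaping each piece. Below is a minimal NumPy sketch of that indexing only. The helper unflatten and the layer shapes are hypothetical, not taken from the linked post. In a TensorFlow graph, the same slices would be taken from the flat variable and passed through tf.reshape, so that the loss is built from the flat tensor:

```python
import numpy as np

def unflatten(flat, shapes):
    """Slice a 1D vector into consecutive pieces and reshape each to its target shape."""
    pieces, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        pieces.append(flat[start:start + size].reshape(shape))
        start += size
    if start != flat.size:
        raise ValueError("shapes do not account for every entry of flat")
    return pieces

# Hypothetical shapes for a small two-layer model: W1, b1, W2, b2.
shapes = [(2, 3), (3,), (3, 1), (1,)]
flat = np.arange(sum(int(np.prod(s)) for s in shapes), dtype=np.float32)
W1, b1, W2, b2 = unflatten(flat, shapes)
print(W1.shape, b1.shape, W2.shape, b2.shape)  # (2, 3) (3,) (3, 1) (1,)
```

The cost the question points out is visible here: every parameter shape has to be known and listed up front, which is what makes this awkward with layers that create their own variables.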

Is there no straightforward way to get the Hessian matrix (as in the 2x2 matrix in this example) from tf.hessians when self.params is a list of tensors of arbitrary shapes? If not, how can you automate the reshaping of the output tensor of tf.hessians?


Solution

  • It turns out (as of TensorFlow r1.13) that if len(xs) > 1, then tf.hessians(ys, xs) returns tensors corresponding only to the block-diagonal submatrices of the full Hessian matrix. The full story and solutions are in this paper, https://arxiv.org/pdf/1905.05559, with code at https://github.com/gknilsen/pyhessian
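
To see what the block-diagonal restriction leaves out, consider a linear model with a weight vector w and a separate bias b. The NumPy sketch below uses the closed-form Hessian of the MSE loss as an illustrative stand-in (it is not the method from the linked paper or repository, and the data values are made up). It compares the full Hessian with the matrix assembled from the per-variable blocks alone:

```python
import numpy as np

X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # made-up inputs

# Design matrix with a column of ones for the bias: theta = [w0, w1, b].
A = np.hstack([X, np.ones((len(X), 1))])

# Closed-form Hessian of mean((X @ w + b - t) ** 2) with respect to theta.
full = (2.0 / len(A)) * (A.T @ A)

# Per-variable Hessians only cover d2L/dw2 and d2L/db2.
block_diag = np.zeros_like(full)
block_diag[:2, :2] = full[:2, :2]
block_diag[2:, 2:] = full[2:, 2:]

cross = full[:2, 2:]  # d2L/dw db, missing from block_diag
print(cross.ravel())                  # [6. 8.]
print(np.allclose(full, block_diag))  # False
```

The off-diagonal block is nonzero here, so concatenating the outputs of tf.hessians for w and b would not reproduce the full 3x3 Hessian.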