python, tensorflow, tensorflow2.x

Computation of second derivatives with batch_jacobian in tensorflow is really slow during training


I am trying to compute the Hessian of the output of a neural network with respect to its inputs. To give you an idea, this is the matrix I am trying to compute:

[Image: definition of the matrix M, the Hessian of the kinetic energy T with respect to the velocities dq]

I am running Tensorflow 2.5.0 and my code to calculate the M matrix looks like this:

def get_Mass_Matrix(self, q, dq):
    nDof = dq.shape[1]
    with tf.GradientTape(persistent = True) as t2:
        t2.watch(dq)
        with tf.GradientTape(persistent = True) as t1:
            t1.watch(dq)
            T = self.kinetic(q, dq)
            
        g = t1.gradient(T, dq)
    h = t2.batch_jacobian(g, dq)
        
    return h 

The function self.kinetic() calls a multilayer perceptron. When I compute M like this, I get the correct answer but my neural network training slows down significantly, even when running on a GPU.

I was wondering if there is a more efficient way to perform the same computation that doesn't result in so much overhead? Thank you.

For reference, I am using the subclassing approach to build the model (it inherits from tf.keras.Model).

Edit:

Adding more details about the self.kinetic function:

def kinetic(self, q, qdot):
    nDof = q.shape[1]
    qdq = tf.concat([tf.reshape(q, ((-1, nDof))),
                      tf.reshape(qdot, ((-1, nDof)))], axis = -1)
    
    return self.T_layers(qdq)

T_layers is defined as:

    self.T_layers = L(nlayers = 4, n = 8, input_dim = (latent_dim, 1), nlact = 'swish', oact = 'linear')

Which is calling:

class L(tf.keras.layers.Layer):

    def __init__(self, nlayers, n, nlact, input_dim, oact = 'linear'):

        super(L, self).__init__()

        self.layers = nlayers
        self.dense_in = tf.keras.layers.Dense(n, activation = nlact, input_shape = input_dim)
        self.dense_lays = []

        for lay in range(nlayers):
            self.dense_lays.append(tf.keras.layers.Dense(n, activation = nlact, kernel_regularizer = 'l1'))

        self.dense_out = tf.keras.layers.Dense(1, activation = oact, use_bias = False)

    def call(self, inputs):
        x = self.dense_in(inputs)
        for lay in range(self.layers):
            x = self.dense_lays[lay](x)

        return self.dense_out(x)

I suspect part of the problem might be that I am not "building" the layers? Any advice is appreciated!


Solution

  • To get reasonable performance from TensorFlow, especially when computing gradients, decorate get_Mass_Matrix with @tf.function so that it runs in graph mode. For this to work, everything inside the function has to be graph-mode compatible.

    In the call function of class L, it is better to iterate over the list directly instead of indexing it, i.e.:

    class L(tf.keras.layers.Layer):
        ...
        def call(self, inputs):
            x = self.dense_in(inputs)
            for l in self.dense_lays:
                x = l(x)
    
            return self.dense_out(x)
    

    Then, you can decorate your get_Mass_Matrix.

    @tf.function
    def get_Mass_Matrix(self, q, dq):
        with tf.GradientTape() as t2:
            t2.watch(dq)
            with tf.GradientTape() as t1:
                t1.watch(dq)
                T = self.kinetic(q, dq)    
            g = t1.gradient(T, dq)
        return t2.batch_jacobian(g, dq) 
    

    Remark: the q and dq passed into get_Mass_Matrix must be tensors whose shape stays constant between calls. Otherwise, the function is retraced every time a new shape appears, which slows things down instead of speeding them up.
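If the batch size cannot be kept fixed, one possible way around the retracing is to give tf.function an input_signature with an unspecified batch dimension. The following is only a minimal sketch under assumptions: N_DOF, MASSES, kinetic and mass_matrix are hypothetical stand-ins (not the asker's model), and the kinetic energy is a simple quadratic chosen so that the expected Hessian, diag(MASSES), is known in advance.

```python
import tensorflow as tf

N_DOF = 3  # hypothetical number of degrees of freedom
MASSES = tf.constant([1.0, 2.0, 3.0])  # hypothetical per-coordinate masses


def kinetic(q, dq):
    # Stand-in for self.kinetic: T = 0.5 * sum_i m_i * dq_i^2.
    # Its Hessian with respect to dq is diag(MASSES) for every sample.
    return 0.5 * tf.reduce_sum(MASSES * tf.square(dq), axis=-1, keepdims=True)


# The batch dimension is left as None, so calls with different batch
# sizes match the same signature instead of triggering a new trace.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, N_DOF], dtype=tf.float32),
    tf.TensorSpec(shape=[None, N_DOF], dtype=tf.float32),
])
def mass_matrix(q, dq):
    with tf.GradientTape() as t2:
        t2.watch(dq)
        with tf.GradientTape() as t1:
            t1.watch(dq)
            T = kinetic(q, dq)
        # First derivative, taken inside t2 so that t2 records it.
        g = t1.gradient(T, dq)       # shape (batch, N_DOF)
    # Per-sample Jacobian of g with respect to dq, i.e. the Hessian of T.
    return t2.batch_jacobian(g, dq)  # shape (batch, N_DOF, N_DOF)


# Two different batch sizes, handled by the same decorated function.
M_small = mass_matrix(tf.zeros([4, N_DOF]), tf.random.normal([4, N_DOF]))
M_large = mass_matrix(tf.zeros([16, N_DOF]), tf.random.normal([16, N_DOF]))
```

Whether this helps in practice depends on the model: a fixed batch size with drop_remainder-style batching avoids the problem altogether, while the signature above trades some shape-specific optimization for not having to retrace.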