Tags: python, tensorflow, keras, gradient, gradienttape

Is it possible to acquire an intermediate gradient? (TensorFlow)


When using a gradient tape, you can calculate the gradient with:

with tf.GradientTape() as tape:
    out = model(x, training=True)
    out = tf.reshape(out, (num_img, 1, 10))  # Resizing
    loss = tf.keras.losses.categorical_crossentropy(y, out)

gradient = tape.gradient(loss, model.trainable_variables)

However, in the case of the CIFAR-10 inputs, this returns the gradients of the input images. Is there a way to access the gradients at an intermediate step, such that the model has already been through "some" training?


Solution

  • EDIT: Thanks to your comment, I got a better understanding of your problem. The following code is far from ideal and does not take batch training, etc. into consideration, but it might give you a good starting point. I wrote a custom training step which basically substitutes for the model.fit method; a short usage sketch follows the model definition below. There might be better methods to do this, but it should give you a quick comparison of gradients.

    def custom_training(model, data):
        x, y = data
        # Training step
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)  # Forward pass
            # Compute the loss value (hard-coded here; in `model.fit`
            # it would be configured in `compile()`)
            loss = tf.keras.losses.mse(y, y_pred)

        trainable_vars = model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # NOTE: a fresh optimizer per call resets its slot variables;
        # in practice you would create one optimizer and reuse it.
        tf.keras.optimizers.Adam().apply_gradients(zip(gradients, trainable_vars))

        # Compute the gradient again after the update, without applying it
        with tf.GradientTape() as tape:
            y_pred = model(x, training=False)  # Forward pass
            loss = tf.keras.losses.mse(y, y_pred)
        gradients_plus = tape.gradient(loss, trainable_vars)

        return gradients, gradients_plus
    

    Let us assume a very simple model:

    import tensorflow as tf
    
    train_data = tf.random.normal((1000, 32))
    train_features = tf.random.normal((1000,))
    
    inputs = tf.keras.layers.Input(shape=(32,))
    hidden_1 = tf.keras.layers.Dense(32)(inputs)
    hidden_2 = tf.keras.layers.Dense(32)(hidden_1)
    outputs = tf.keras.layers.Dense(1)(hidden_2)
    
    model = tf.keras.Model(inputs, outputs)
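
    With this model in place, the custom_training step above can be exercised directly. A minimal usage sketch (the reshape is my addition so that mse compares matching (batch, 1) shapes):

    y = tf.reshape(train_features, (-1, 1))
    gradients, gradients_plus = custom_training(model, (train_data, y))

    # Compare the gradient norms before and after the single update
    for var, g, g_plus in zip(model.trainable_variables, gradients, gradients_plus):
        print(var.name, tf.norm(g).numpy(), tf.norm(g_plus).numpy())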
    

    Suppose you want to compute the gradients of each layer's output with respect to the inputs. You can use the following:

    inputs = train_data
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(inputs)  # watch the input tensor so we can differentiate w.r.t. it
        out_intermediate = []
        cargo = inputs
        for layer in model.layers[1:]:  # model.layers[0] is the InputLayer
            cargo = layer(cargo)
            out_intermediate.append(cargo)

    for x in out_intermediate:
        print(tape.gradient(x, inputs))
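
    If instead you need the gradient of a loss at each intermediate activation, the same persistent tape works, because the activations are computed inside the tape's context. A minimal sketch reusing the model and data above (the squeeze and the mean-squared loss are my choices for illustration):

    with tf.GradientTape(persistent=True) as tape:
        cargo = train_data
        activations = []
        for layer in model.layers[1:]:
            cargo = layer(cargo)
            activations.append(cargo)
        # Scalar loss; squeeze aligns the (1000, 1) output with the (1000,) targets
        loss = tf.reduce_mean(tf.square(train_features - tf.squeeze(cargo, axis=-1)))

    for act in activations:
        # Gradient of the loss w.r.t. each intermediate activation
        print(tape.gradient(loss, act).shape)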
    

    If you want to compute a custom loss, I recommend the TensorFlow guide Customize what happens in Model.fit.
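
    That guide's pattern is to subclass tf.keras.Model and override train_step, which keeps the conveniences of model.fit while giving you control over the gradient computation. A minimal sketch of that pattern (compiled_loss and compiled_metrics are the TF 2.x, pre-Keras-3 API):

    class CustomModel(tf.keras.Model):
        def train_step(self, data):
            x, y = data
            with tf.GradientTape() as tape:
                y_pred = self(x, training=True)
                loss = self.compiled_loss(y, y_pred)  # loss configured in compile()
            gradients = tape.gradient(loss, self.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            self.compiled_metrics.update_state(y, y_pred)
            return {m.name: m.result() for m in self.metrics}

    # Wrap the existing functional graph in the custom class and train as usual
    custom_model = CustomModel(model.input, model.output)
    custom_model.compile(optimizer="adam", loss="mse")
    custom_model.fit(train_data, tf.reshape(train_features, (-1, 1)), epochs=1)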