Tags: python, tensorflow, loss-function, regularized, gradienttape

TensorFlow: Calculating gradients of regularization loss terms dependent on model input and output


Overview

My model is an encoder that has input Z and output x.

I'm trying to train with a total_loss that combines a traditional supervised-learning term with one or more regularization terms. Additional functions (outside the network) take the input Z and the predicted output x_pred and compute x_hat and Z_pred, which are then compared against x and Z to form the regularization losses included in the loss calculation.

import tensorflow as tf

# Custom training function within model class
def train_step(self, Z, x):
    # Define loss object
    loss_object = tf.keras.losses.MeanSquaredError()
    with tf.GradientTape() as tape:
        # Get encoder output
        x_pred = self.encoder(Z)

        # Calculate traditional supervised learning data loss
        data_loss = loss_object(x, x_pred)

        # Calculate regularization terms
        x_hat, Z_pred = calc_reg_terms(x_pred, Z)  # physics-informed function (outside the network)
        # Calculate respective regularization losses
        loss_x = loss_object(x, x_hat)
        loss_z = loss_object(Z, Z_pred)

    """<Additional Code>"""

Question

What is the correct method for calculating the gradient of my total_loss?

In the past, I've tried simply adding all the loss terms together, then taking the gradient of the summed loss.

### PAST METHOD ###
# Calculate total loss (this must run inside the tf.GradientTape() context)
total_loss = data_loss + a * loss_x + b * loss_z  # a, b -> set hyperparameters
# Get gradients
grads = tape.gradient(total_loss, self.trainable_weights)
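A minimal end-to-end sketch of this pattern, assuming a toy two-layer Dense encoder (8 inputs, 4 outputs), Adam as the optimizer, fixed weights a = b = 0.1, and a hypothetical differentiable stand-in for calc_reg_terms (a non-negativity projection plus an assumed linear operator A). It also makes one detail explicit: total_loss has to be formed inside the `with tf.GradientTape()` block, because operations executed after the block exits are not recorded, and tape.gradient then returns None for every weight.

```python
import tensorflow as tf

# Hypothetical stand-in for the physics-informed calc_reg_terms:
# x_hat is a non-negativity projection and Z_pred uses an assumed
# linear operator A mapping x (4 features) back to Z (8 features).
A = tf.random.normal((4, 8), seed=1)

def calc_reg_terms(x_pred, Z):
    x_hat = tf.nn.relu(x_pred)
    Z_pred = tf.matmul(x_pred, A)
    return x_hat, Z_pred

# Toy encoder (assumption): Z (8 features) -> x (4 features)
encoder = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="tanh"),
    tf.keras.layers.Dense(4),
])
encoder(tf.zeros((1, 8)))  # build the weights before training
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_object = tf.keras.losses.MeanSquaredError()
a, b = 0.1, 0.1  # hypothetical regularization weights

def train_step(Z, x):
    with tf.GradientTape() as tape:
        x_pred = encoder(Z)
        data_loss = loss_object(x, x_pred)
        x_hat, Z_pred = calc_reg_terms(x_pred, Z)
        loss_x = loss_object(x, x_hat)
        loss_z = loss_object(Z, Z_pred)
        # Summed inside the tape, so the addition is recorded too
        total_loss = data_loss + a * loss_x + b * loss_z
    grads = tape.gradient(total_loss, encoder.trainable_weights)
    optimizer.apply_gradients(zip(grads, encoder.trainable_weights))
    return total_loss, grads

def grads_when_summed_outside(Z, x):
    # Anti-pattern for comparison: the sum happens after the tape closes
    with tf.GradientTape() as tape:
        data_loss = loss_object(x, encoder(Z))
    total_loss = data_loss + 0.0  # not recorded by the tape
    return tape.gradient(total_loss, encoder.trainable_weights)

Z = tf.random.normal((32, 8))
x = tf.random.normal((32, 4))
total_loss, grads = train_step(Z, x)
print(all(g is not None for g in grads))                        # True
print(all(g is None for g in grads_when_summed_outside(Z, x)))  # True
```

If the real code builds total_loss after the `with` block, moving that line inside the block is enough; nothing else in the gradient call needs to change.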

However, since loss_x and loss_z are computed outside the encoder, I worry that they act more like a bias on the total_loss calculation than a useful training signal: the model actually performs worse when these losses are added to data_loss. The data_loss term is clearly connected to the encoder's trainable weights, so its gradient is well defined, but that connection is much less obvious for my regularization loss terms.
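Whether a term computed "outside the encoder" can train the encoder depends only on whether the tape can trace it back to the weights. Any differentiable TensorFlow operations applied to x_pred inside the tape are recorded, so loss_x and loss_z send gradients back through x_pred just as data_loss does. The chain breaks if the helper leaves TensorFlow, for example by converting to NumPy. A minimal sketch contrasting the two cases, assuming a single Dense layer as the encoder and two hypothetical versions of calc_reg_terms (the operator A is also an assumption):

```python
import numpy as np
import tensorflow as tf

# Toy encoder (assumption): one Dense layer mapping Z (8 features) to x (4)
encoder = tf.keras.Sequential([tf.keras.layers.Dense(4)])
encoder(tf.zeros((1, 8)))  # build kernel and bias
loss_object = tf.keras.losses.MeanSquaredError()
A = tf.random.normal((4, 8), seed=0)  # hypothetical physics operator x -> Z

def calc_reg_terms_tf(x_pred, Z):
    # Pure TensorFlow ops: the tape records them
    return tf.nn.relu(x_pred), tf.matmul(x_pred, A)

def calc_reg_terms_numpy(x_pred, Z):
    # .numpy() leaves TensorFlow, so the tape loses the link to x_pred
    x_np = x_pred.numpy()
    return tf.constant(np.maximum(x_np, 0.0)), tf.constant(x_np @ A.numpy())

def reg_grads(reg_fn, Z, x):
    with tf.GradientTape() as tape:
        x_pred = encoder(Z)
        x_hat, Z_pred = reg_fn(x_pred, Z)
        reg_loss = loss_object(x, x_hat) + loss_object(Z, Z_pred)
    return tape.gradient(reg_loss, encoder.trainable_weights)

Z = tf.random.normal((16, 8))
x = tf.random.normal((16, 4))
print([g is None for g in reg_grads(calc_reg_terms_tf, Z, x)])     # [False, False]
print([g is None for g in reg_grads(calc_reg_terms_numpy, Z, x)])  # [True, True]
```

If a real calc_reg_terms relies on NumPy or SciPy internally, rewriting it with TensorFlow ops is one way to restore the gradient path.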

NOTE: Tracking each of the three losses during training shows that data_loss keeps decreasing from epoch to epoch, while loss_x and loss_z both plateau early in training, hence my fear that they act as an unwanted bias on total_loss.
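A plateau like this can be probed by comparing how strongly each term pulls on the weights. A minimal sketch, assuming a toy Dense encoder and a hypothetical calc_reg_terms stand-in, that computes the global gradient norm of each loss separately using a persistent tape: a norm of None means the term is disconnected from the weights, while a norm that is tiny (or huge) relative to data_loss's may hint that a or b is scaled poorly.

```python
import tensorflow as tf

# Toy encoder and hypothetical calc_reg_terms stand-in (assumptions)
A = tf.random.normal((4, 8), seed=2)

def calc_reg_terms(x_pred, Z):
    return tf.nn.relu(x_pred), tf.matmul(x_pred, A)

encoder = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="tanh"),
    tf.keras.layers.Dense(4),
])
encoder(tf.zeros((1, 8)))  # build the weights
loss_object = tf.keras.losses.MeanSquaredError()

def per_term_grad_norms(Z, x):
    # persistent=True allows one tape.gradient call per loss term
    with tf.GradientTape(persistent=True) as tape:
        x_pred = encoder(Z)
        x_hat, Z_pred = calc_reg_terms(x_pred, Z)
        losses = {
            "data_loss": loss_object(x, x_pred),
            "loss_x": loss_object(x, x_hat),
            "loss_z": loss_object(Z, Z_pred),
        }
    norms = {}
    for name, loss in losses.items():
        grads = tape.gradient(loss, encoder.trainable_weights)
        # None means this term does not reach the weights at all
        if any(g is None for g in grads):
            norms[name] = None
        else:
            norms[name] = float(tf.linalg.global_norm(grads))
    del tape  # release the resources held by the persistent tape
    return norms

Z = tf.random.normal((32, 8))
x = tf.random.normal((32, 4))
for name, norm in per_term_grad_norms(Z, x).items():
    print(f"{name}: gradient norm = {norm}")
```

Logging these norms every few epochs alongside the loss values would show whether the plateau coincides with the regularization gradients shrinking or simply being dwarfed by data_loss.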

What, then, is the proper way to calculate the gradients from the data_loss, loss_x, and loss_z terms?


Solution

Thanks for the clarification in your comment; it makes sense.

Your code looks correct to me; that is the general approach. Compute total_loss as the data loss plus constant-weighted regularization losses (in your case data_loss + a * loss_x + b * loss_z), take the gradient of total_loss, and backpropagate. A simple way to confirm that it works, without running a full hyperparameter sweep, is to set a=0 and b=0, then increase a from a very small value (e.g., a=1E-10) to a large one (e.g., a=1) while keeping b=0. You can take big steps, but your training and validation losses should change as you sweep across values of a. Then repeat the same process with b. If everything behaves as expected, continue to the full hyperparameter sweep.
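The sanity-check sweep above can be sketched as follows, assuming a toy linear target, a small Dense encoder, a hypothetical calc_reg_terms stand-in, and a few full-batch Adam steps per run; the a values follow the range in the answer (0, then 1E-10 up to 1). Reseeding before each run keeps the initial weights identical, so differences between runs come from a (and later b) rather than from initialization.

```python
import tensorflow as tf

# Toy problem and hypothetical stand-ins (assumptions): swap in the real
# encoder, calc_reg_terms, and train/validation data.
tf.random.set_seed(0)
A = tf.random.normal((4, 8))
W_true = tf.random.normal((8, 4))
Z_train, Z_val = tf.random.normal((256, 8)), tf.random.normal((64, 8))
x_train, x_val = tf.matmul(Z_train, W_true), tf.matmul(Z_val, W_true)
loss_object = tf.keras.losses.MeanSquaredError()

def calc_reg_terms(x_pred, Z):
    return tf.nn.relu(x_pred), tf.matmul(x_pred, A)

def run(a, b, epochs=20):
    # Same initial weights for every run, so only a and b differ
    tf.keras.utils.set_random_seed(0)
    encoder = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation="tanh"),
        tf.keras.layers.Dense(4),
    ])
    encoder(tf.zeros((1, 8)))
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)
    for _ in range(epochs):  # full-batch steps, for brevity
        with tf.GradientTape() as tape:
            x_pred = encoder(Z_train)
            x_hat, Z_pred = calc_reg_terms(x_pred, Z_train)
            data_loss = loss_object(x_train, x_pred)
            total_loss = (data_loss
                          + a * loss_object(x_train, x_hat)
                          + b * loss_object(Z_train, Z_pred))
        grads = tape.gradient(total_loss, encoder.trainable_weights)
        optimizer.apply_gradients(zip(grads, encoder.trainable_weights))
    # Train data_loss from the last step and validation data_loss after training
    val_data_loss = loss_object(x_val, encoder(Z_val))
    return float(data_loss), float(val_data_loss)

# Step 1: b = 0, sweep a from tiny to large in big steps
a_values = [0.0, 1e-10, 1e-6, 1e-3, 1e-1, 1.0]
results = {a: run(a, b=0.0) for a in a_values}
for a, (train_loss, val_loss) in results.items():
    print(f"a={a:g}  train data_loss={train_loss:.4f}  val data_loss={val_loss:.4f}")
# Step 2: repeat with a = 0 and the same range for b, i.e. run(0.0, b)
```

Comparing the plain data_loss across runs (rather than total_loss, whose scale depends on a) keeps the numbers directly comparable.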