Implementing a Custom Loss Function in TensorFlow for Regression Problem with Group-Specific MSEs

I am working on a regression problem using TensorFlow, where I have encountered a challenge with my loss function. My data points are structured as triples $(Y_i, G_i, X_i)$, where
$Y_i \in \mathbb{R}$ represents an outcome; $G_i \in \{0,1\}$ is a binary group identifier; $X_i \in \mathbb{R}^d$ is a feature vector.

The goal is to train a neural network that predicts $\hat{Y}$ given $X$, using a custom loss function that is the absolute difference in Mean Squared Errors (MSE) between the two groups.

Formally, for a prediction algorithm $f \colon \mathbb{R}^d \to \mathbb{R}$, the mean squared error on group $k$ is defined as $$e_k(f) := \frac{\sum_{i \colon G_i=k} (Y_i - f(X_i))^2}{\sum_i \mathbb{1}\{G_i=k\}}.$$ The loss is then $|e_0(f) - e_1(f)|$. Alternatively, one can minimize $(e_0(f) - e_1(f))^2$, which has the same minimizers and is additionally differentiable where $e_0(f) = e_1(f)$.

I am struggling to implement this loss function in TensorFlow, as it is not a straightforward sum of individual data point losses. The challenge lies in the loss depending on multiple data points across both groups.

My main questions are

  • How can I structure this loss function in TensorFlow, considering its dependence on group-wise calculations?
  • Are there any specific TensorFlow functions or techniques that would simplify the implementation of such a group-based loss function? Any guidance or suggestions on how to proceed would be greatly appreciated.

What I tried

import numpy as np
import tensorflow as tf
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

def custom_loss(group, squared=False):
    """Build a loss measuring the gap between the two groups' MSEs.

    Args:
        group: Group labels (1 for 'b', 0 for 'r') aligned element-wise with
            the ``y_true``/``y_pred`` later passed to the returned function.
            Any shape that flattens to the batch length is accepted.
        squared: If True the loss is ``(mse_b - mse_r) ** 2``; otherwise
            ``|mse_b - mse_r|`` (the default, matching the original
            behaviour). The squared form is smooth at zero, which usually
            makes gradient-based training more stable.

    Returns:
        A function ``loss(y_true, y_pred)`` returning a scalar tensor. If
        either group is absent from the batch the gap is undefined; the
        function then returns 0 (so the batch contributes zero gradient)
        instead of NaN, which would otherwise propagate into the weights.
    """
    # The masks depend only on the labels, so build them once per factory call.
    group = tf.reshape(group, [-1])
    mask_b = tf.equal(group, 1)
    mask_r = tf.equal(group, 0)

    def loss(y_true, y_pred):
        y_true = tf.reshape(y_true, [-1])
        # Model outputs may have a different float dtype than the targets
        # (e.g. float32 vs numpy float64); align to the target dtype.
        y_pred = tf.cast(tf.reshape(y_pred, [-1]), y_true.dtype)
        sq_err = tf.square(y_true - y_pred)

        def group_mse(mask):
            # Returns (mean squared error, member count) for one group.
            errs = tf.boolean_mask(sq_err, mask)
            count = tf.cast(tf.size(errs), errs.dtype)
            # divide_no_nan yields 0 (not NaN) for an empty group, keeping
            # both branches of the tf.where below NaN-free for gradients.
            return tf.math.divide_no_nan(tf.reduce_sum(errs), count), count

        mse_b, n_b = group_mse(mask_b)
        mse_r, n_r = group_mse(mask_r)

        diff = mse_b - mse_r
        gap = tf.square(diff) if squared else tf.abs(diff)
        both_present = tf.logical_and(n_b > 0, n_r > 0)
        return tf.where(both_present, gap, tf.zeros_like(gap))
    return loss

# Since the loss depends on per-group averages, every batch must contain
# members of both groups: larger batches give less noisy group MSEs, smaller
# batches give more gradient updates per epoch.
def train_early_stopping(model, custom_loss,
                         X_train, y_train, g_train, X_val, y_val, g_val,
                         n_epoch=500, patience=10, batch_size=1000,
                         shuffle=False, seed=None):
    """Train ``model`` on a group-dependent loss with early stopping.

    Runs a manual mini-batch loop, because the loss needs each batch's group
    labels. After every epoch it evaluates the loss on the full validation
    set. It stops once that loss has not improved for ``patience``
    consecutive epochs. On return the model holds the weights from the epoch
    with the lowest validation loss.

    Args:
        model: A compiled Keras model; ``model.optimizer`` is used to apply
            the gradients.
        custom_loss: Factory called as ``custom_loss(groups)`` that returns a
            ``loss(y_true, y_pred)`` function producing a scalar tensor.
        X_train, y_train, g_train: Training features, targets and group
            labels, sliceable along the first axis.
        X_val, y_val, g_val: Validation features, targets and group labels.
        n_epoch: Maximum number of epochs.
        patience: Consecutive epochs without validation improvement before
            training stops.
        batch_size: Mini-batch size; the final batch may be smaller.
        shuffle: If True, reshuffle the training set at the start of each
            epoch (requires numpy-style fancy indexing).
        seed: Seed for the shuffling RNG; only used when ``shuffle`` is True.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_train = X_train.shape[0]
    # Ceiling division so the trailing partial batch is trained on too.
    n_steps = -(-n_train // batch_size)
    rng = np.random.default_rng(seed)

    # Initialize variables for early stopping
    best_val_loss = float('inf')
    best_weights = None
    best_epoch = 0
    wait = 0

    for epoch in range(n_epoch):
        X_ep, y_ep, g_ep = X_train, y_train, g_train
        if shuffle:
            perm = rng.permutation(n_train)
            X_ep, y_ep, g_ep = X_ep[perm], y_ep[perm], g_ep[perm]

        loss_epoch_list = []
        for step in range(n_steps):
            start = step * batch_size
            end = start + batch_size  # slicing clamps end to n_train
            X_batch = X_ep[start:end]
            y_batch = y_ep[start:end]
            g_batch = g_ep[start:end]

            # Record only the forward pass; computing the gradient outside
            # the tape context avoids also recording the backward pass.
            with tf.GradientTape() as tape:
                y_pred = model(X_batch, training=True)
                loss_value = custom_loss(g_batch)(y_batch, y_pred)
            grads = tape.gradient(loss_value, model.trainable_variables)
            model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
            loss_epoch_list.append(float(loss_value))

        train_loss = float(np.mean(loss_epoch_list)) if loss_epoch_list else float('nan')
        # Validation loss on the whole set at once, so group MSEs use all of it.
        val_loss = float(custom_loss(g_val)(y_val, model.predict(X_val, verbose=0)))
        print(f"Epoch {epoch+1}: Train Loss: {train_loss}, Validation Loss: {val_loss}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_weights = model.get_weights()
            best_epoch = epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f"Early Stopping triggered at epoch {epoch + 1}; "
                      f"best epoch {best_epoch + 1}, Validation Loss: {best_val_loss}")
                break
    else:
        # All n_epoch epochs ran without early stopping.
        print('Not converged.')

    # Roll back to the best validation checkpoint.
    if best_weights is not None:
        model.set_weights(best_weights)

# Create a synthetic dataset
X, y = make_regression(n_samples=20000, n_features=10, noise=0.2, random_state=42)
# Seeded so the group assignment (and hence the whole experiment) is reproducible.
group = np.random.default_rng(42).integers(0, 2, size=y.shape)  # 1 for 'b', 0 for 'r'
X_train_full, X_test, y_train_full, y_test, g_train_full, g_test = train_test_split(
    X, y, group, test_size=0.5, random_state=42)
X_train, X_val, y_train, y_val, g_train, g_val = train_test_split(
    X_train_full, y_train_full, g_train_full, test_size=0.2, random_state=42)

# main
num_unit = 64

model_fair = tf.keras.Sequential([
    tf.keras.layers.Dense(num_unit, activation='relu', input_shape=(X.shape[1],)),
    tf.keras.layers.Dense(num_unit, activation='relu'),
    # Single linear output unit: the regression prediction.
    tf.keras.layers.Dense(1),
])
# The manual training loop calls model.optimizer.apply_gradients, so the model
# must be compiled to attach an optimizer; the loss itself is computed by hand.
model_fair.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))

# Small batches give many gradient updates per epoch while still containing
# plenty of members from both (roughly balanced) groups.
batch_size = 64

train_early_stopping(model_fair, custom_loss, X_train, y_train, g_train, X_val, y_val, g_val,
                     patience=10, batch_size=batch_size)

The code executes without generating any error. However, the training, validation, and test set losses differ significantly, indicating potential issues with the training process.


  • The problem in the initial code was solved by Pawel -- thank you so much! The modified code in the question worked well after making the following changes:

    • choose a much smaller batch size (e.g. 64)
    • change the loss from the absolute loss to the squared loss
    • shuffle the train set for each epoch -- this is not necessary but improves the result