I am working on a regression problem using TensorFlow, where I have encountered a challenge with my loss function. My data points are structured as triples $(Y_i, G_i, X_i)$, where
$Y_i \in \mathbb{R}$ is an outcome, $G_i \in \{0,1\}$ is a binary group identifier, and $X_i \in \mathbb{R}^d$ is a feature vector.
The goal is to train a neural network that predicts $\hat{Y}$ given $X$, using a custom loss function that is the absolute difference in Mean Squared Errors (MSE) between the two groups.
Formally, for a prediction function $f \colon \mathbb{R}^d \to \mathbb{R}$, the per-group MSE is defined as $$e_k(f) := \frac{\sum_{i \colon G_i=k} (Y_i - f(X_i))^2}{\sum_i \mathbb{1}\{G_i=k\}}.$$ The loss is then $|e_0(f) - e_1(f)|$. Alternatively, one can minimize $(e_0(f) - e_1(f))^2$, which has the same minimizers because squaring is monotone on non-negative values.
I am struggling to implement this loss function in TensorFlow, because it is not a sum of per-sample losses: each group's MSE depends on all samples in that group, and the loss couples the two groups.
My main questions are
import numpy as np
import tensorflow as tf
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
def custom_loss(group):
    """Build a loss measuring the gap between the two groups' MSEs.

    The returned callable computes ``|mse_b - mse_r|``, where ``mse_b`` is the
    mean squared error over samples with group label 1 and ``mse_r`` over
    samples with group label 0. The labels are captured by closure, so they
    must be aligned element-wise with the ``y_true``/``y_pred`` later passed to
    the returned function (e.g. build it per batch).

    Args:
        group: 0/1 group labels, one per sample. Any shape whose total size
            equals the number of samples is accepted; it is flattened.

    Returns:
        A function ``loss(y_true, y_pred)`` returning a scalar tensor in the
        dtype of ``y_true``. If the samples contain no member of one of the
        groups, the disparity is undefined and the loss is 0 (with a zero
        gradient) instead of NaN.
    """
    def loss(y_true, y_pred):
        # Flatten so (n, 1) model outputs line up with (n,) targets; otherwise
        # y_true - y_pred would broadcast to an (n, n) matrix.
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.cast(tf.reshape(y_pred, [-1]), y_true.dtype)
        g = tf.reshape(group, [-1])

        # Float indicator weights instead of boolean_mask: same per-group
        # means, but an empty group yields a count of 0 we can guard below.
        w_b = tf.cast(tf.equal(g, 1), y_true.dtype)
        w_r = tf.cast(tf.equal(g, 0), y_true.dtype)
        sq_err = tf.square(y_true - y_pred)

        n_b = tf.reduce_sum(w_b)
        n_r = tf.reduce_sum(w_r)
        # divide_no_nan yields 0 for an empty group instead of 0/0 = NaN,
        # which would otherwise poison the gradients and then the weights.
        mse_b = tf.math.divide_no_nan(tf.reduce_sum(w_b * sq_err), n_b)
        mse_r = tf.math.divide_no_nan(tf.reduce_sum(w_r * sq_err), n_r)

        # Both branches of tf.where are NaN-free, so the gradient is safe too.
        both_present = tf.logical_and(n_b > 0, n_r > 0)
        return tf.where(both_present, tf.abs(mse_b - mse_r), tf.zeros_like(mse_b))
    return loss
# The loss compares per-group averages within each batch, so batch_size must be large enough that both groups are well represented in every batch.
def train_early_stopping(model, custom_loss,
                         X_train, y_train, g_train, X_val, y_val, g_val,
                         n_epoch=500, patience=10, batch_size=1000,
                         shuffle=True, seed=None):
    """Train ``model`` with a group-dependent loss and early stopping.

    Each step evaluates ``custom_loss(g_batch)(y_batch, model(X_batch))`` and
    applies its gradient with ``model.optimizer``, so the model must have
    been compiled with an optimizer. After every epoch the same criterion is
    computed on the validation set; training stops once it has not improved
    for ``patience`` consecutive epochs. The weights from the best validation
    epoch are restored in every case (early stop or epochs exhausted).

    Args:
        model: Keras model with an attached optimizer.
        custom_loss: factory ``custom_loss(groups) -> loss(y_true, y_pred)``.
        X_train, y_train, g_train: training features, targets, group labels.
        X_val, y_val, g_val: validation features, targets, group labels.
        n_epoch: maximum number of epochs.
        patience: epochs without validation improvement before stopping.
        batch_size: samples per step; a trailing partial batch is dropped.
            If larger than the training set, the whole set is one batch.
        shuffle: reshuffle the training set at the start of every epoch.
        seed: seed for the shuffling RNG (``None`` for nondeterministic).

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    g_train = np.asarray(g_train)
    n_train = X_train.shape[0]
    # At least one step per epoch, even when batch_size > n_train.
    n_steps = max(n_train // batch_size, 1)
    rng = np.random.default_rng(seed)

    best_val_loss = float('inf')
    # Start from the initial weights so restoring is always defined, even if
    # the validation loss never improves (e.g. it is NaN every epoch).
    best_weights = model.get_weights()
    best_epoch = 0
    wait = 0

    for epoch in range(n_epoch):
        # Reshuffle each epoch; fixed contiguous batches would always compare
        # the same subsets of the two groups.
        order = rng.permutation(n_train) if shuffle else np.arange(n_train)
        loss_epoch_list = []
        for step in range(n_steps):
            idx = order[step * batch_size:(step + 1) * batch_size]
            X_batch, y_batch, g_batch = X_train[idx], y_train[idx], g_train[idx]
            # Only the forward pass needs to be recorded on the tape.
            with tf.GradientTape() as tape:
                y_pred = model(X_batch, training=True)
                loss_value = custom_loss(g_batch)(y_batch, y_pred)
            grads = tape.gradient(loss_value, model.trainable_variables)
            model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
            loss_epoch_list.append(float(loss_value))

        # Validation loss for early stopping (verbose=0: no per-epoch progress bar).
        val_loss = float(custom_loss(g_val)(y_val, model.predict(X_val, verbose=0)))
        print(f"Epoch {epoch+1}: Train Loss: {np.mean(loss_epoch_list)}, Validation Loss: {val_loss}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_weights = model.get_weights()
            best_epoch = epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f"Early Stopping triggered at epoch {epoch + 1} "
                      f"(best epoch {best_epoch + 1}), Validation Loss: {best_val_loss}")
                break
    else:
        # Loop ran all n_epoch epochs without triggering early stopping.
        print('Not converged.')

    model.set_weights(best_weights)
# Create a synthetic dataset
X, y = make_regression(n_samples=20000, n_features=10, noise=0.2, random_state=42)
# Seed the group assignment as well, so runs are reproducible like the splits below.
group = np.random.default_rng(42).choice([0, 1], size=y.shape)  # 1 for 'b', 0 for 'r'
X_train_full, X_test, y_train_full, y_test, g_train_full, g_test = train_test_split(
    X, y, group, test_size=0.5, random_state=42)
X_train, X_val, y_train, y_val, g_train, g_val = train_test_split(
    X_train_full, y_train_full, g_train_full, test_size=0.2, random_state=42)

# main
num_unit = 64
model_fair = tf.keras.Sequential([
    # Explicit Input layer instead of the deprecated `input_shape=` kwarg.
    tf.keras.Input(shape=(X.shape[1],)),
    tf.keras.layers.Dense(num_unit, activation='relu'),
    tf.keras.layers.Dense(num_unit, activation='relu'),
    tf.keras.layers.Dense(1),
])
# compile() is only used to attach the optimizer; the loss is applied manually.
model_fair.compile(optimizer='adam')
batch_size = X_train.shape[0] // 5
train_early_stopping(model_fair, custom_loss, X_train, y_train, g_train, X_val, y_val, g_val,
                     patience=10, batch_size=batch_size)

# Held-out evaluation with the same criterion used for training and validation.
test_loss = custom_loss(g_test)(y_test, model_fair.predict(X_test, verbose=0))
print(f"Test Loss: {float(test_loss)}")
The code runs without errors. However, the training, validation, and test losses differ substantially, which suggests a problem with the training process.
The problem in the initial code was solved by Pawel -- thank you so much! The modified code in the question worked well after making the following changes: