EarlyStopping based on convergence of a trainable variable in TF/Keras

Suppose I have a custom layer that computes the loss for me from external trainable variables, using TF 2.4 (yes, I know the example and its loss are silly; they are just for reproducibility, and the actual loss is much more complex):

import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Layer, Input
from tensorflow.keras import Model
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow as tf

n_col = 10
n_row = 1000
X = np.random.normal(size=(n_row, n_col))
beta = np.arange(10)
y = X @ beta

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

class MyLoss(Layer):
    def __init__(self, var1, var2):
        super(MyLoss, self).__init__()
        self.var1 = tf.Variable(var1)
        self.var2 = tf.Variable(var2)

    def get_vars(self):
        return self.var1, self.var2

    def custom_loss(self, y_true, y_pred):
        return self.var1 ** 2 * tf.math.reduce_mean(tf.math.square(y_true-y_pred)) + self.var2 ** 2

    def call(self, y_true, y_pred):
        self.add_loss(self.custom_loss(y_true, y_pred))
        return y_pred

inputs = Input(shape=(X_train.shape[1],))
y_input = Input(shape=(1,))
hidden1 = Dense(10)(inputs)
output = Dense(1)(hidden1)
my_loss = MyLoss(0.5, 0.5)(y_input, output) # here can also initialize those var1, var2
model = Model(inputs=[inputs, y_input], outputs=my_loss)

model.compile(optimizer='adam')

Training this model is simple:

history = model.fit([X_train, y_train], None,
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                    callbacks=[EarlyStopping(monitor='val_loss', patience=5)])

If we write a custom Callback or train epoch by epoch, we can see var1 and var2 converge to 0, as expected:

var1_list = []
var2_list = []
for i in range(100):
    if i % 10 == 0:
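        print('step %d' % i)
    # Train for a single epoch, then record the current variable values.
    model.fit([X_train, y_train], None,
              batch_size=32, epochs=1, validation_split=0.1, verbose=0)
    var1, var2 = model.layers[-1].get_vars()
    var1_list.append(var1.numpy())
    var2_list.append(var2.numpy())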

plt.plot(var1_list, label='var1')
plt.plot(var2_list, 'r', label='var2')
plt.legend()
plt.show()

Short question: how do I make the model stop (EarlyStopping with some patience) based on the convergence of var1 and var2, i.e. of their squared norm, self.var1**2 + self.var2**2? Again, assume the loss is much more complex and you cannot just add this norm to the loss.

Longer question (if you have the time/patience):

  • Is it possible to implement a custom Metric and make EarlyStopping track it?
  • In that case, how would you make EarlyStopping focus on "convergence" when all it's got is mode "min" or "max"? (I wonder whether we could extend EarlyStopping instead of extending Callback.)
  • Can we do this without a metric, with a custom Callback?
  • How would we combine this with the custom loss above, telling EarlyStopping to pay attention to both, i.e. "stop if you don't see improvement in loss AND improvement in convergence for patience=10"?


  • Well, at least for the "shorter question" this turned out to be quite simple: following this example from the TF docs, I implemented early stopping with the twist of focusing on the variables' norm:

    class EarlyStoppingAtVarsConvergence(tf.keras.callbacks.Callback):
        def __init__(self, norm_thresh=0.01, patience=0):
            super(EarlyStoppingAtVarsConvergence, self).__init__()
            self.norm_thresh = norm_thresh
            self.patience = patience
        def on_train_begin(self, logs=None):
            # The number of epoch it has waited when norm hasn't converged.
            self.wait = 0
            # The epoch the training stops at.
            self.stopped_epoch = 0
            # Initialize the variables' norm.
            self.vars_norm = self.get_vars_norm()
        def get_vars_norm(self):
            # Use the model attached to this callback rather than a global.
            var1, var2 = self.model.layers[-1].get_vars()
            return var1**2 + var2**2
        def on_epoch_end(self, epoch, logs=None):
            current_norm = self.get_vars_norm()
            if np.abs(current_norm - self.vars_norm) > self.norm_thresh:
                # The norm is still moving: update the reference and reset the counter.
                self.vars_norm = current_norm
                self.wait = 0
            else:
                # The norm changed by less than the threshold: count this epoch.
                self.wait += 1
                if self.wait >= self.patience:
                    self.stopped_epoch = epoch
                    self.model.stop_training = True
        def on_train_end(self, logs=None):
            if self.stopped_epoch > 0:
                print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
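As an aside on the last bullet of the question (stopping only when neither the loss nor the variables' norm makes progress), the patience bookkeeping can be pulled out into a small framework-free helper. This is only a sketch under my own assumptions: `CombinedPatience`, `min_delta` and `update` are hypothetical names, not part of Keras, and "progress" is defined here as the loss dropping by more than `min_delta` or the norm moving by more than `norm_thresh`.

```python
import math


class CombinedPatience:
    """Patience counter over two signals (hypothetical helper, not a Keras API).

    An epoch counts as progress if the loss improved by more than `min_delta`
    OR the variables' norm moved by more than `norm_thresh`. Stopping is
    signalled after `patience` consecutive epochs with neither.
    """

    def __init__(self, norm_thresh=0.01, min_delta=0.0, patience=10):
        self.norm_thresh = norm_thresh
        self.min_delta = min_delta
        self.patience = patience
        self.best_loss = math.inf   # best loss seen so far
        self.last_norm = None       # reference norm, updated only when it moves
        self.wait = 0               # consecutive epochs without progress

    def update(self, loss, norm):
        """Record one epoch; return True when training should stop."""
        loss_improved = loss < self.best_loss - self.min_delta
        norm_moved = (self.last_norm is None
                      or abs(norm - self.last_norm) > self.norm_thresh)
        if loss_improved:
            self.best_loss = loss
        if norm_moved:
            self.last_norm = norm
        if loss_improved or norm_moved:
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience
```

Inside a callback like the one above, `update` could be called from `on_epoch_end` with `logs["val_loss"]` and the current norm as plain floats, setting `self.model.stop_training = True` when it returns True. I have not tried this against the real, more complex loss.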

    Then the model would be run with:

    history = model.fit([X_train, y_train], None,
                        batch_size=32, epochs=100, validation_split=0.1, verbose=0,