python tensorflow machine-learning keras neural-network

Understanding the behavior of Keras EarlyStopping


I'm using TensorFlow 2.4.0. Here is the code of the tf.keras EarlyStopping callback, specifically the on_epoch_end method, which is called at the end of each epoch:

def on_epoch_end(self, epoch, logs=None):
  current = self.get_monitor_value(logs)
  if self.monitor_op(current - self.min_delta, self.best):
    self.best = current
    self.wait = 0
    if self.restore_best_weights:
      self.best_weights = self.model.get_weights()
  else:
    self.wait += 1
    if self.wait >= self.patience:
      self.stopped_epoch = epoch
      self.model.stop_training = True
      if self.restore_best_weights:
        if self.verbose > 0:
          print('Restoring model weights from the end of the best epoch.')
        self.model.set_weights(self.best_weights)

Since the monitored quantity in my case is val_loss, the comparison operator is:

self.monitor_op = np.less

In essence, the code performs this logic:

If (current - min_delta) < best:
      best = current;
      wait = 0
Otherwise:
      wait += 1;
      if wait >= patience:
            stop training

Then if, for example:

  • min_delta = 0.1
  • current = 0.9
  • best = 0.85

we have that (current - min_delta) < best, thus:

  • best = current (=0.9)
  • wait = 0

So best is now associated with a worse value than the previous one (0.9 instead of 0.85). Is this the correct/expected behavior of EarlyStopping? It seems strange.


Solution

  • If you take a look at the __init__ method of the EarlyStopping class, you should see something like this:

    if self.monitor_op == np.greater:
      self.min_delta *= 1
    else:
      self.min_delta *= -1
    

    So, since we know self.monitor_op = np.less, min_delta is actually -0.1 in your case, and the if statement evaluates (0.9 - (-0.1)) < 0.85, i.e. 1.0 < 0.85, which is False. As a result, best stays at 0.85 and wait is incremented instead. I am assuming you have an EarlyStopping callback defined as:

    es = EarlyStopping(monitor='val_loss', mode='min')
    

    Note also what min_delta actually is:

    Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement.
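
To make the sign flip concrete, here is a minimal pure-Python sketch of the logic discussed above, run on the numbers from the question. EarlyStoppingSketch is a hypothetical stand-in, not the Keras class: it uses operator.lt / operator.gt in place of np.less / np.greater, takes the monitored value directly instead of a logs dict, and assumes min_delta is passed through abs() before the sign flip.

```python
import operator


class EarlyStoppingSketch:
    """Hypothetical, simplified stand-in for the EarlyStopping logic."""

    def __init__(self, min_delta=0.0, patience=0, mode="min"):
        # operator.lt / operator.gt stand in for np.less / np.greater.
        self.monitor_op = operator.gt if mode == "max" else operator.lt
        # Assumption of this sketch: min_delta is stored as a magnitude.
        self.min_delta = abs(min_delta)
        # Same sign flip as in the snippet above: negative when lower is better.
        self.min_delta *= 1 if self.monitor_op is operator.gt else -1
        self.patience = patience
        self.wait = 0
        self.best = float("-inf") if mode == "max" else float("inf")
        self.stop_training = False

    def on_epoch_end(self, current):
        # Improvement only if current beats best by more than |min_delta|.
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop_training = True


es = EarlyStoppingSketch(min_delta=0.1, patience=3, mode="min")
es.best = 0.85        # pretend an earlier epoch reached val_loss = 0.85
es.on_epoch_end(0.9)  # 0.9 - (-0.1) = 1.0, which is not < 0.85
print(es.min_delta, es.best, es.wait)  # -0.1 0.85 1
```

With these values best is left at 0.85 and wait goes up by one, so a val_loss of 0.9 is not treated as an improvement. In this sketch a later epoch only resets wait when its loss is more than 0.1 below the current best.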