Early stopping callback in Keras


How can I stop model.fit() from inside a Keras callback? I have tried several approaches, including the one below.

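import tensorflow as tf

class EarlyStoppingCallback(tf.keras.callbacks.Callback):
    """Asks model.fit() to stop once the epoch's accuracy reaches `threshold`."""

    def __init__(self, threshold):
        super(EarlyStoppingCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        accuracy = logs.get("accuracy")
        # Only request a stop when the metric is present and above the threshold.
        if accuracy is not None and accuracy >= self.threshold:
            print("Stopping early!")
            self.model.stop_training = True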

The callback runs and the message is printed, but setting self.model.stop_training = True has no effect: the model keeps training. How can I fix this? I am using tensorflow==1.14.0.


Solution

  • You're probably affected by the following issue: https://github.com/tensorflow/tensorflow/issues/37587.

    In short: whenever model.predict or model.evaluate is called, model.stop_training is reset to False. I was able to reproduce this behavior using your EarlyStoppingCallback followed by another callback that calls model.predict on a fixed dataset.

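    The workaround is to list any callbacks that call model.predict or model.evaluate before any callback that might set model.stop_training to True. Callbacks run in list order and fit only checks the flag after all of them have finished the epoch, so the callback that requests the stop has to run last. The issue appears to have been fixed in TF 2.2.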
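The ordering effect can be illustrated without TensorFlow. The sketch below is a framework-free simulation, not Keras code: FakeModel, StopAtThreshold and PredictEachEpoch are hypothetical stand-ins. It assumes, as described in the linked issue, that fit checks stop_training only after every callback's on_epoch_end has run, and that predict resets the flag to False.

```python
# Framework-free simulation of the callback-ordering problem.
# Hypothetical classes; they only mimic the behavior reported in the issue.


class FakeModel:
    def __init__(self):
        self.stop_training = False

    def predict(self, data):
        # Mimics the reported bug: predict clears the stop flag.
        self.stop_training = False
        return [x * 2 for x in data]

    def fit(self, epochs, callbacks, accuracies):
        """Run up to `epochs` epochs and return how many actually ran."""
        self.stop_training = False
        for cb in callbacks:
            cb.model = self
        epochs_run = 0
        for epoch in range(epochs):
            epochs_run += 1
            logs = {"accuracy": accuracies[epoch]}
            # Every callback sees the end of the epoch, in list order...
            for cb in callbacks:
                cb.on_epoch_end(epoch, logs)
            # ...and the flag is checked only after all of them have run.
            if self.stop_training:
                break
        return epochs_run


class StopAtThreshold:
    def __init__(self, threshold):
        self.threshold = threshold
        self.model = None

    def on_epoch_end(self, epoch, logs=None):
        if logs and logs.get("accuracy", 0.0) >= self.threshold:
            self.model.stop_training = True


class PredictEachEpoch:
    def __init__(self, data):
        self.data = data
        self.model = None

    def on_epoch_end(self, epoch, logs=None):
        self.model.predict(self.data)


accuracies = [0.5, 0.7, 0.95, 0.96, 0.97]

# Stopper first: the later predict() call wipes out its stop request.
bad = FakeModel().fit(5, [StopAtThreshold(0.9), PredictEachEpoch([1, 2])], accuracies)
# Predicting callback first: the stopper has the last word.
good = FakeModel().fit(5, [PredictEachEpoch([1, 2]), StopAtThreshold(0.9)], accuracies)
print(bad, good)  # 5 3
```

With real tf.keras the corresponding change is only the order of the list passed as callbacks= to model.fit, for example callbacks=[predicting_callback, EarlyStoppingCallback(0.9)], where predicting_callback is a placeholder for whichever of your callbacks calls model.predict.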