Tags: python, tensorflow, keras, talos

Keras EarlyStopping callback working inconsistently


To cut training time, I use Keras' EarlyStopping callback (through the talos.utils.early_stopper wrapper) when training my neural network:

history = model.fit(x=X_train, 
                    y=y_train, 
                    validation_data=(X_val, y_val), 
                    batch_size=params["batch"], 
                    epochs=params["epoch"], 
                    callbacks=[talos.utils.early_stopper(epochs=params["epoch"], mode='strict', min_delta=0.001)], 
                    verbose=1)
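For comparison, here is a minimal sketch that skips the wrapper and constructs Keras' EarlyStopping directly, so min_delta and patience are exactly the values passed in. This may matter because, in the talos versions I'm aware of, the preset modes ('moderate', 'strict') build the callback with their own fixed min_delta (0) and patience, which would mean the min_delta=0.001 above never reaches the callback. That is an assumption worth confirming in the early_stopper source of your installed talos version. The patience of 2 is also an assumption, chosen because Exhibit A stops after two non-improving epochs.

```python
import tensorflow as tf

# Build the Keras callback directly so the thresholds are exactly what we pass.
# Assumptions: min_delta=0.001 is the value from the question; patience=2 mirrors
# the stop after two non-improving epochs seen in Exhibit A.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,  # val_loss must drop by more than this to count as improvement
    patience=2,       # stop after 2 consecutive epochs without such an improvement
    mode="min",       # lower val_loss is better
    verbose=1,
)

# Usage, with the names from the question:
# history = model.fit(x=X_train, y=y_train,
#                     validation_data=(X_val, y_val),
#                     batch_size=params["batch"],
#                     epochs=params["epoch"],
#                     callbacks=[early_stop],
#                     verbose=1)
```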

However, it seems to behave inconsistently:

Exhibit A

Epoch 1/42
160/160 [==============================] - 19s 73ms/step - loss: 116.8279 - accuracy: 0.3124 - val_loss: 0.5561 - val_accuracy: 0.3708
Epoch 2/42
160/160 [==============================] - 6s 36ms/step - loss: 0.5676 - accuracy: 0.3440 - val_loss: 0.5564 - val_accuracy: 0.3708
Epoch 3/42
160/160 [==============================] - 6s 35ms/step - loss: 0.5720 - accuracy: 0.3337 - val_loss: 0.5573 - val_accuracy: 0.3708
-> TRAIN STOPS

Exhibit B

Epoch 14/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5664 - accuracy: 0.3501 - val_loss: 0.5581 - val_accuracy: 0.3708
Epoch 15/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5684 - accuracy: 0.3414 - val_loss: 0.5575 - val_accuracy: 0.3708
Epoch 16/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5699 - accuracy: 0.3348 - val_loss: 0.5570 - val_accuracy: 0.3708
Epoch 17/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5686 - accuracy: 0.3415 - val_loss: 0.5567 - val_accuracy: 0.3708
Epoch 18/42
160/160 [==============================] - 6s 38ms/step - loss: 0.5666 - accuracy: 0.3457 - val_loss: 0.5566 - val_accuracy: 0.3708
Epoch 19/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5694 - accuracy: 0.3367 - val_loss: 0.5563 - val_accuracy: 0.3708
Epoch 20/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5671 - accuracy: 0.3418 - val_loss: 0.5562 - val_accuracy: 0.3708
Epoch 21/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5654 - accuracy: 0.3472 - val_loss: 0.5561 - val_accuracy: 0.3708
Epoch 22/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5673 - accuracy: 0.3416 - val_loss: 0.5560 - val_accuracy: 0.3708
Epoch 23/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5637 - accuracy: 0.3542 - val_loss: 0.5560 - val_accuracy: 0.3708
Epoch 24/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5670 - accuracy: 0.3417 - val_loss: 0.5560 - val_accuracy: 0.3708
Epoch 25/42
160/160 [==============================] - 6s 38ms/step - loss: 0.5652 - accuracy: 0.3495 - val_loss: 0.5560 - val_accuracy: 0.3708
Epoch 26/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5663 - accuracy: 0.3454 - val_loss: 0.5560 - val_accuracy: 0.3708
Epoch 27/42
160/160 [==============================] - 6s 38ms/step - loss: 0.5679 - accuracy: 0.3384 - val_loss: 0.5560 - val_accuracy: 0.3708
Epoch 28/42
160/160 [==============================] - 6s 38ms/step - loss: 0.5639 - accuracy: 0.3505 - val_loss: 0.5560 - val_accuracy: 0.3708
Epoch 29/42
160/160 [==============================] - 7s 42ms/step - loss: 0.5636 - accuracy: 0.3515 - val_loss: 0.5559 - val_accuracy: 0.3708
Epoch 30/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5680 - accuracy: 0.3399 - val_loss: 0.5559 - val_accuracy: 0.3708
Epoch 31/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5696 - accuracy: 0.3338 - val_loss: 0.5559 - val_accuracy: 0.3708
Epoch 32/42
160/160 [==============================] - 6s 39ms/step - loss: 0.5705 - accuracy: 0.3321 - val_loss: 0.5559 - val_accuracy: 0.3708
Epoch 33/42
160/160 [==============================] - 6s 40ms/step - loss: 0.5724 - accuracy: 0.3273 - val_loss: 0.5559 - val_accuracy: 0.3708
-> TRAIN STOPS

Why didn't training stop earlier in Exhibit B, even though val_loss is clearly not improving by more than min_delta from one epoch to the next? I've looked at the talos source, and early_stopper appears to be just a wrapper around the Keras callback; everything there looks fine. I've also noticed that this only seems to happen when val_loss hovers between 0.5559 and 0.5560.
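To see how much min_delta changes the outcome, here is a small pure-Python re-implementation of the "min"-mode improvement rule as the Keras docs describe it: a new value counts as an improvement only if it beats the best value so far by more than min_delta. The helper `stop_epoch` is hypothetical (not part of Keras or talos), and it is applied to the val_loss values as printed in the logs. Two assumptions: the logged values are rounded to 4 decimals, and Exhibit B is treated as if training started at epoch 14, because the earlier best value isn't shown.

```python
def stop_epoch(values, min_delta, patience, start_epoch=1):
    """Return the epoch at which a "min"-mode EarlyStopping-style rule would stop,
    or None if it never triggers. Hypothetical helper mirroring the documented
    semantics, not the actual Keras implementation."""
    best = float("inf")
    wait = 0
    for offset, value in enumerate(values):
        epoch = start_epoch + offset
        if value + min_delta < best:  # dropped by more than min_delta: improvement
            best = value
            wait = 0
        else:
            wait += 1  # one more epoch without a sufficient improvement
            if wait >= patience:
                return epoch
    return None

# val_loss as printed in the logs (rounded to 4 decimals).
exhibit_a = [0.5561, 0.5564, 0.5573]  # epochs 1-3
exhibit_b = [0.5581, 0.5575, 0.5570, 0.5567, 0.5566, 0.5563, 0.5562,
             0.5561, 0.5560, 0.5560, 0.5560, 0.5560, 0.5560, 0.5560,
             0.5560, 0.5559, 0.5559, 0.5559, 0.5559, 0.5559]  # epochs 14-33

print(stop_epoch(exhibit_a, min_delta=0.001, patience=2))                  # 3
print(stop_epoch(exhibit_b, min_delta=0.001, patience=2, start_epoch=14))  # 18
print(stop_epoch(exhibit_b, min_delta=0.0, patience=2, start_epoch=14))    # 24
```

With min_delta=0.001 the rule stops Exhibit B at epoch 18, and even with min_delta=0 it stops at epoch 24 on the rounded values. The real run continued until epoch 33, which suggests (the rounded log cannot prove it) that the callback was effectively using min_delta=0, with val_loss still shrinking by amounts below the fourth decimal place.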

FWIW, I'm running this on Colab with a TPU.

Thanks!


Solution

  • For some reason, switching the monitored metric from val_loss to val_accuracy, i.e. EarlyStopping(monitor="val_accuracy", min_delta=0.01, patience=2, verbose=1, mode='auto'), seems to give more consistent stopping behavior.
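One possible reason the val_accuracy monitor looks more consistent: in both exhibits val_accuracy is 0.3708 at every epoch, so it never improves at all. A sketch of the "max"-mode rule (hypothetical helper `stop_epoch_max`, mirroring the documented min_delta semantics rather than Keras's actual code) on a constant series shows it would always stop at epoch 1 + patience, independent of val_loss.

```python
def stop_epoch_max(values, min_delta, patience, start_epoch=1):
    """Epoch at which a "max"-mode EarlyStopping-style rule would stop (None if never).
    A value counts as an improvement only if it exceeds the best so far by more
    than min_delta. Hypothetical helper for illustration."""
    best = float("-inf")
    wait = 0
    for offset, value in enumerate(values):
        if value - min_delta > best:  # rose by more than min_delta: improvement
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return start_epoch + offset
    return None

# val_accuracy as printed in the logs: 0.3708 at every epoch, in both exhibits.
flat_val_accuracy = [0.3708] * 42  # a full 42-epoch run, assuming it stays flat

print(stop_epoch_max(flat_val_accuracy, min_delta=0.01, patience=2))  # 3
```

If that is what is happening, training stops consistently because validation accuracy is flat, not because the callback became more reliable.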