tensorflow keras hdf5 auc early-stopping

Early stopping based on AUC


I am fairly new to ML and am currently implementing a simple 3D CNN in Python using TensorFlow and Keras. I want to optimize based on the AUC and would also like to use early stopping and save the best network in terms of AUC score. I have been using TensorFlow's AUC metric for this, as shown below, and it works well for training. However, the HDF5 file is not saved (despite save_best_only=True on the checkpoint), so I cannot load the best weights for the evaluation.

Here are the relevant lines of code:

model.compile(loss='binary_crossentropy',
              optimizer=keras.optimizers.Adam(lr=lr),
              metrics=[tf.keras.metrics.AUC()]) 

model.load_weights(path_weights)

filepath = mypath

check = tf.keras.callbacks.ModelCheckpoint(filepath, monitor=tf.keras.metrics.AUC(), save_best_only=True,
                                           mode='auto')

earlyStopping = tf.keras.callbacks.EarlyStopping(monitor=tf.keras.metrics.AUC(), patience=hyperparams['pat'],mode='auto') 

history = model.fit(X_trn, y_trn,
                        batch_size=bs,
                        epochs=n_epochs,
                        verbose=1,
                        callbacks=[check, earlyStopping],
                        validation_data=(X_val, y_val),
                        shuffle=True)

Interestingly, if I only change the monitor to 'val_loss' in the early stopping and checkpoint callbacks (leaving 'metrics' in model.compile untouched), the HDF5 file is saved, but it obviously reflects the best result in terms of validation loss. I have also tried mode='max', but the problem is the same. I would very much appreciate your advice, or any other constructive ideas on how to work around this problem.


Solution

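  • It turns out that even if you pass a metric object (rather than a string keyword) to metrics, you still need to refer to it by its name (its handle), as a string, when you want to monitor it. The monitor argument expects the key under which the metric appears in the training logs, not the metric object itself. In your case you can do this: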

    auc = tf.keras.metrics.AUC(name='auc')  # name it explicitly; this string is the handle used below
    
    model.compile(loss='binary_crossentropy',
                  optimizer=keras.optimizers.Adam(learning_rate=lr),  # 'lr' is the deprecated spelling
                  metrics=[auc]) 
    
    ...
    
    check = tf.keras.callbacks.ModelCheckpoint(filepath,
                                               monitor='auc',  # the metric's name as a string; this monitors the training AUC
                                               save_best_only=True,
                                               mode='max')  # a higher AUC counts as a better model
    

    If you want to monitor the validation AUC instead (which makes more sense), simply prefix the name with val_:

    check = tf.keras.callbacks.ModelCheckpoint(filepath,
                                               monitor='val_auc',  # validation AUC
                                               save_best_only=True,
                                               mode='max')
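
The question also uses EarlyStopping, and the same naming rule applies there. The following is a minimal, self-contained sketch, assuming TensorFlow 2.x with its bundled Keras. The synthetic data, the tiny dense model (standing in for the 3D CNN), the temporary file path, and all hyperparameter values are placeholders invented for illustration. Passing name="auc" explicitly is a precaution, since auto-generated metric names can pick up a numeric suffix in some versions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Tiny synthetic binary-classification data (illustrative only).
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8)).astype("float32")
y = (X[:, 0] + X[:, 1] > 0).astype("float32")
X_trn, y_trn, X_val, y_val = X[:150], y[:150], X[150:], y[150:]

# Small placeholder model; the question's 3D CNN would go here instead.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="binary_crossentropy",
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
              metrics=[tf.keras.metrics.AUC(name="auc")])

# Hypothetical output path; the ".weights.h5" suffix is what newer Keras
# versions expect when save_weights_only=True.
filepath = os.path.join(tempfile.mkdtemp(), "best.weights.h5")

# Both callbacks monitor the metric by its string name, with mode="max".
check = tf.keras.callbacks.ModelCheckpoint(filepath,
                                           monitor="val_auc",
                                           save_best_only=True,
                                           save_weights_only=True,
                                           mode="max")
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_auc",
                                                  patience=3,
                                                  mode="max")

history = model.fit(X_trn, y_trn,
                    batch_size=32,
                    epochs=5,
                    verbose=0,
                    callbacks=[check, early_stopping],
                    validation_data=(X_val, y_val))

# The keys printed here are the only names the callbacks can monitor.
print(sorted(history.history.keys()))

# Restore the weights from the epoch with the highest validation AUC.
model.load_weights(filepath)
```

Printing history.history.keys() after a short run is a quick way to check which names are available before wiring up the callbacks.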
    

    Another problem is that your ModelCheckpoint is saving the weights based on the minimum AUC instead of the maximum, which is what you want.

    This can be changed by setting mode='max'.


    What does mode='auto' do?

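    With this setting, the callback infers the direction from the name passed to monitor: if the name contains 'acc', it uses mode='max'. In any other case it uses mode='min', which is what is happening in your case. The exact rule may differ between Keras versions, so setting mode='max' explicitly is the safer choice.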

    You can confirm this in the source code of the Keras callbacks.
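
To make the effect of the wrong direction concrete, here is a simplified paraphrase of the rule as described in this answer. It is not the actual Keras source: the helper names resolve_auto_mode and best_epoch are hypothetical, and the AUC values are made up for illustration.

```python
def resolve_auto_mode(monitor, mode="auto"):
    """Hypothetical helper mirroring the rule described above (not Keras code)."""
    if mode in ("min", "max"):
        return mode
    # mode == "auto": guess the direction from the metric name.
    return "max" if "acc" in monitor else "min"


def best_epoch(values, mode):
    """Return the index of the epoch a save_best_only checkpoint would keep."""
    best_index = 0
    for i, value in enumerate(values):
        improved = value > values[best_index] if mode == "max" else value < values[best_index]
        if improved:
            best_index = i
    return best_index


val_auc_per_epoch = [0.70, 0.80, 0.75]  # made-up validation AUC values

auto_mode = resolve_auto_mode("val_auc")            # name has no "acc" in it
kept_with_auto = best_epoch(val_auc_per_epoch, auto_mode)
kept_with_max = best_epoch(val_auc_per_epoch, resolve_auto_mode("val_auc", "max"))

print(auto_mode, kept_with_auto, kept_with_max)
```

Under this rule, 'val_auc' resolves to 'min', so the checkpoint would keep the epoch with the lowest AUC, whereas mode='max' keeps the epoch with the highest.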