
How to get the training & validation loss of the Keras scikit-learn wrapper in cross-validation?


I know that model.fit in Keras returns a callbacks.History object, from which we can get the loss and other metrics as follows.

...
train_history = model.fit(X_train, Y_train,
                    batch_size=batch_size, epochs=epochs,
                    verbose=1, validation_data=(X_test, Y_test))
loss = train_history.history['loss']
val_loss = train_history.history['val_loss']

However, in my new experiment I am cross-validating a Keras model with scikit-learn's KerasClassifier wrapper (full example code: https://chrisalbon.com/deep_learning/keras/k-fold_cross-validating_neural_networks/):

# Wrap Keras model so it can be used by scikit-learn
neural_network = KerasClassifier(build_fn=create_network, 
                                 epochs=10, 
                                 batch_size=100, 
                                 verbose=1)

Since I am now using cross-validation, I am unsure how to get the training and validation loss.
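
For reference, the cross-validation step in the linked example scores the wrapped model roughly like this (a sketch; cross_val_score comes from scikit-learn, and features and target are the arrays built in that example):

from sklearn.model_selection import cross_val_score

# One accuracy score per fold -- but no History object, and hence
# no per-epoch training or validation loss
scores = cross_val_score(neural_network, features, target, cv=3)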


Solution

  • As mentioned explicitly in the documentation, cross_val_score includes a scoring argument, which is

    Similar to cross_validate but only a single metric is permitted.

    hence it cannot return all the loss & metric information that Keras model.fit() produces.
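
    For illustration, even choosing a loss-like metric as that single scoring option yields only one final number per fold, never a per-epoch curve (a minimal sketch, assuming the neural_network wrapper and the features & target arrays from the linked example):

    from sklearn.model_selection import cross_val_score

    # One value per fold: with scoring='neg_log_loss' this is the negated
    # final log-loss (KerasClassifier exposes predict_proba, which this
    # scorer needs) -- still no access to the per-epoch History
    scores = cross_val_score(neural_network, features, target,
                             cv=3, scoring='neg_log_loss')
    print(scores.shape)  # (3,)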

    The scikit-learn wrapper of Keras is meant as a convenience, provided that you are not really interested in all the underlying details (such as training & validation loss and accuracy). If this is not the case, you should revert to using Keras directly. Here is how you could do that using the example you have linked to and elements of this answer of mine:

    import numpy as np
    from keras import models, layers
    from sklearn.datasets import make_classification
    from sklearn.model_selection import KFold
    
    np.random.seed(0)
    
    # Number of features
    number_of_features = 100
    
    # Generate features matrix and target vector
    features, target = make_classification(n_samples = 10000,
                                           n_features = number_of_features,
                                           n_informative = 3,
                                           n_redundant = 0,
                                           n_classes = 2,
                                           weights = [.5, .5],
                                           random_state = 0)
    
    def create_network():
        # A simple binary classifier: two hidden ReLU layers, sigmoid output
        network = models.Sequential()
        network.add(layers.Dense(units=16, activation='relu', input_shape=(number_of_features,)))
        network.add(layers.Dense(units=16, activation='relu'))
        network.add(layers.Dense(units=1, activation='sigmoid'))
    
        network.compile(loss='binary_crossentropy',
                        optimizer='rmsprop',
                        metrics=['accuracy'])
    
        return network
    
    n_splits = 3
    kf = KFold(n_splits=n_splits, shuffle=True)
    
    loss = []
    acc = []
    val_loss = []
    val_acc = []
    
    # cross validate:
    for train_index, val_index in kf.split(features):
        model = create_network()
        hist = model.fit(features[train_index], target[train_index],
                         epochs=10,
                         batch_size=100,
                         validation_data = (features[val_index], target[val_index]),
                         verbose=0)
        loss.append(hist.history['loss'])
        # NOTE: newer Keras / tf.keras versions use the keys 'accuracy' and
        # 'val_accuracy' instead of 'acc' and 'val_acc'
        acc.append(hist.history['acc'])
        val_loss.append(hist.history['val_loss'])
        val_acc.append(hist.history['val_acc'])
    

    After which loss, for example, will be:

    [[0.7251979386058971,
      0.6640552306833333,
      0.6190941931069023,
      0.5602273066015956,
      0.48771809028534785,
      0.40796665995284814,
      0.33154681897220617,
      0.2698465999525444,
      0.227492357244586,
      0.1998490962115201],
     [0.7109123742507104,
      0.674812126485093,
      0.6452083222258479,
      0.6074533335751673,
      0.5627432800365635,
      0.51291748379345,
      0.45645068427406726,
      0.3928780094229408,
      0.3282097149542538,
      0.26993170230619656],
     [0.7191790426458682,
      0.6618405645963258,
      0.6253172250296091,
      0.5855853647883192,
      0.5438901918195831,
      0.4999895181964501,
      0.4495182811042725,
      0.3896359298090465,
      0.3210068798340545,
      0.25932698793518183]]
    

    i.e. a list of n_splits lists (here 3), each of which contains the training loss for each epoch (here 10). Similarly for the other lists.
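
    If you want a single learning curve per metric, you can average the collected lists epoch-wise across the folds; a minimal sketch, reusing np and the lists from the loop above:

    # Stack into (n_splits, n_epochs) arrays and average over the folds
    mean_loss = np.mean(np.array(loss), axis=0)
    mean_val_loss = np.mean(np.array(val_loss), axis=0)

    for epoch, (tr, va) in enumerate(zip(mean_loss, mean_val_loss), start=1):
        print('epoch %2d: train loss %.4f | val loss %.4f' % (epoch, tr, va))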