Tags: keras, scikit-learn, gridsearchcv, k-fold

KFold validation_data in Keras model.fit using scikit-learn GridSearchCV


I'm working with Keras, using scikit-learn's GridSearchCV and KFold together with the SciKeras wrappers. I would like to pass the validation folds of the KFold to the model's fit method by means of the validation_data parameter. I tried several alternatives but couldn't get it to work. Here's the code.

NN = KerasClassifier(
  model=get_NN,
  X_len = len(X_train.columns),
  loss="mse",
  optimizer="SGD",
  epochs=300,
  batch_size=4,
  shuffle=True,
  verbose=False,
  # fit__validation_data = # Here I should pass the validation data
  callbacks=[
    tf.keras.callbacks.EarlyStopping(
      monitor="val_loss", min_delta=0.0001, patience=15, restore_best_weights=True
    )
  ]
)

custom_scores_monk = {
    "accuracy": "accuracy",
    "mse": make_scorer(mean_squared_error,greater_is_better=False)
}

NN_MONK1_GRID_DICT = {
  "model__lr" : [0.5], 
  "model__alpha" : [0.8],
  "model__hidden_activation" : ["tanh"],
  "model__neurons" : [4], 
  "model__initializer" : ["glorot"], 
  "model__nesterov" : [True], 
  "model__penalty": [None], 
  "model__lambda_reg": [None],
  "model__seed" : [15]
}

grid = GridSearchCV(NN,
                    param_grid=NN_MONK1_GRID_DICT,
                    scoring=custom_scores_monk,
                    refit="mse",
                    cv=CV,
                    return_train_score=True,
                    n_jobs=-1
        )
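
For context, the search itself is then launched in the usual way (a minimal sketch; X_train and y_train are the pandas objects mentioned above):

grid.fit(X_train, y_train)

print(grid.best_params_)   # best hyper-parameter combination
print(-grid.best_score_)   # best mean MSE (negated because the scorer uses greater_is_better=False)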

Among the other alternatives, I tried writing a custom callback that updates the validation data on_train_begin, but it seems like a dirty practice, and I'm not surprised it doesn't work.

class ValidationCallback(Callback):
  def __init__(self, X, y, validation_split):
    super().__init__()
    self.X = X
    self.y = y
    self.validation_split = validation_split
    self.count = 0

  def on_train_begin(self, logs=None):
    print("Training " + str(self.count))
    indexes = self.validation_split[self.count]
    X_val, y_val = [self.X.iloc[i] for i in indexes], [self.y.iloc[i] for i in indexes]
    self.count = self.count+1
    self.model.fit__validation_data = (X_val, y_val)

On the other hand, I'm very surprised there is no ready-made solution for such a common task as K-Fold cross-validation, especially with a framework like scikit-learn. In particular, this problem makes it impossible to use 'val_loss' as the monitored value for early stopping, and also to plot and compare the training and validation learning curves.

Do you have any solutions?


Solution

  • I spent about a week on this and finally found a way.

    Short answer: don't do it. Just handwrite an ad-hoc method for grid search and use it.

    Long answer: you can define a subclass of the SciKeras wrapper in order to redefine the fit method, passing the current fold to it. To do that, you must:

    1. know the folds that will be used; thus you must set a random_state in your CV object so that the split is reproducible
        # define a split strategy using a random_state
        CV = StratifiedKFold(n_splits=5, random_state=42, shuffle=True)
        
        # get the validation folds
        val_split = [ test for (train, test) in CV.split(X_train, y_train) ]
        
        val_data = [ 
          (
            [X_train.iloc[i].tolist() for i in indexes], 
            [y_train.iloc[i].tolist() for i in indexes]
          ) for indexes in val_split 
        ]
    
    1. define a "static" counter for the folds
        # static fold counter: call reset_counter() once before the first fit
        # so that count.count exists and starts from -1
        def count():
          count.count += 1
          return count.count
        
        def reset_counter():
          count.count = -1
        
        def get_count():
          return count.count
    
    3. in the same way, you must define a registry for memorizing the various history objects
        # static history register: call histories() once to (re)initialize the list before the first register()
        def histories():
          histories.histories = []
        
        def register(h):
          histories.histories.append(h)
        
        def get_histories():
          return histories.histories
        
        def clear_histories():
          histories()
    
    4. define a method for computing the mean of the K histories. It allows doing early stopping on the validation loss (a tiny usage example is shown after this list).
        # utilities to get the mean of K histories
        from statistics import mean   # assumption: any mean over a list works here; statistics.mean is used
        
        def add_padding(ls, n):
          ls.extend([ls[-1]] * n)
          return ls
        
        def mean_epochs(l):
          return int(mean([ len(item['loss']) for item in l ]))
        
        def mean_history(_histories):
          m = mean_epochs(_histories)+1
          for history in _histories:
            l = len(history['loss'])
            for field in _histories[0]:
              if l>= m:
                history[field] = history[field][:m]
              else:
                history[field] = add_padding(history[field], (m-l))
          return \
            { field : 
                [ 
                  (sum(x)/len(_histories)) for x in zip(
                    *[ history[field] for history in _histories ]
                  )
                ] for field in _histories[0]
            }
    
    5. extend the SciKeras wrapper class, redefining the fit method
        # KerasClassifier Wrapper for kfold
        class KCWrapper(KerasClassifier):
        
          # you can pass the same parameters you passed to the KerasClassifier, after val_data and k
          def __init__(self, val_data, k, *args, **kwargs):
            super(KCWrapper, self).__init__(*args, **kwargs)
            self.val_data = val_data
            self.k = k
          
          def fit(self, X, y, **kwargs):
            h = super().fit(X, y, validation_data=self.val_data[count()], **kwargs)
            register(h.history_)
            # do_NN_plot(h.history_)  # plot single fold curve
            if self.kfold_finished(): # plot mean of k folds curves
              do_NN_plot(mean_history(get_histories()))
            return h  # keep scikit-learn's convention: fit returns the fitted estimator
            
          def kfold_finished(self):
            return self.k == get_count()+1
    
    6. instantiate the (wrapper of the wrapper of the) classifier (a usage sketch with GridSearchCV is shown further below)
        # Fixed parameters for the wrapped KerasClassifier
        kerasClassifierParams = {
          "model" : get_NN,
          "X_len" : len(X_train.columns),
          "loss" : "mse",
          "optimizer" : "SGD",
          "epochs" : 300,
          "batch_size" : 4,
          "shuffle" : True,
          "verbose" : False
        }
    
        NN = KCWrapper(
          val_data,
          5, # 5-Fold
          callbacks=[
            tf.keras.callbacks.EarlyStopping(
              monitor="val_loss", min_delta=0.0001, patience=20, restore_best_weights=True
            )
          ],
          **kerasClassifierParams
        )
    

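    A tiny usage example of the averaging utilities from step 4 (the numbers are purely illustrative, chosen so the arithmetic is easy to check): the shorter history is padded with its last value before the per-epoch means are computed.

        h1 = {"loss": [1.0, 0.5, 0.25], "val_loss": [1.0, 0.75, 0.5]}
        h2 = {"loss": [0.75, 0.5],      "val_loss": [1.0, 0.5]}
        
        print(mean_history([h1, h2]))
        # {'loss': [0.875, 0.5, 0.375], 'val_loss': [1.0, 0.625, 0.5]}
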
    The provided code also uses a routine for plotting data:

        import matplotlib.pyplot as plt
        
        def do_NN_plot(history):
        
          # Plot accuracy
          plt.plot(history['binary_accuracy'])
          plt.plot(history['val_binary_accuracy'], linestyle="--", color="orange")
          plt.title('model accuracy')
          plt.ylabel('accuracy')
          plt.xlabel('epoch')
          plt.legend(['training', 'validation'], loc='lower right')
          plt.show()
        
          # Plot loss
          plt.plot(history['loss'])
          plt.plot(history['val_loss'], linestyle="--", color="orange")
          plt.title('model MSE')
          plt.ylabel('MSE')
          plt.xlabel('epoch')
          plt.legend(['training', 'validation'], loc='upper right')
          plt.show()
    

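    Putting the pieces together, the wrapped classifier can then be fed to GridSearchCV roughly like this (a sketch, reusing custom_scores_monk and NN_MONK1_GRID_DICT from the question; the shared fold counter and history register must be (re)initialized before every search):

        reset_counter()   # fold counter starts from -1
        histories()       # empty history register
        
        grid = GridSearchCV(NN,
                            param_grid=NN_MONK1_GRID_DICT,
                            scoring=custom_scores_monk,
                            refit=False,  # a final refit would advance the fold counter past the last fold
                            cv=CV,
                            return_train_score=True,
                            n_jobs=1      # the counter is a plain global, so no parallel workers
               )
        
        grid.fit(X_train, y_train)
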
    If you're working on a regression task, you can do the same thing with a wrapper of (a wrapper of) a regressor.

    The data split, fold counter, history register, history-averaging utilities and plotting routine are exactly the same as in the classifier example above (with a continuous target you would typically use a plain KFold instead of StratifiedKFold); only the wrapper itself changes:
    
    
        # KerasRegressor Wrapper for kfold
        class KRWrapper(KerasRegressor):
        
          def __init__(self, val_data, k, *args, **kwargs):
            super(KRWrapper, self).__init__(*args, **kwargs)
            self.val_data = val_data
            self.k = k
            
          def fit(self, X, y, **kwargs):
            h = super().fit(X, y, validation_data=self.val_data[count()], **kwargs)
            register(h.history_)
            # do_NN_plot(h.history_)  # plot single fold curve
            if self.kfold_finished(): # plot mean of k folds curves
              do_NN_plot(mean_history(get_histories()))
            return h  # as above, keep fit returning the fitted estimator
            
          def kfold_finished(self):
            return self.k == get_count()+1
    
        # Fixed parameters for the wrapped KerasRegressor
        kerasRegressorParams = {
          "model" : get_NN,
          "X_len" : len(X_train.columns),
          "loss" : mee_NN,
          "optimizer" : "SGD", # fixed into get_NN
          "batch_size" : 32,
          "epochs" : 2000,
          "shuffle" : True,
          "verbose" : 0
        }
        
        NN = KRWrapper(
          val_data,
          5,
          callbacks=[
            tf.keras.callbacks.EarlyStopping(
              monitor="val_loss", min_delta=0.000001, patience=50, restore_best_weights=True
            )
          ],
          **kerasRegressorParams
        )
    
    

    This satisfied my curiosity and stubbornness, but it's a dirty solution (even if it's still a solution :P). As I said at the beginning: just hand-write an ad hoc grid-search method and use it (a sketch is given below). The solution presented above doesn't let you use the built-in parallelization of scikit-learn's GridSearchCV, so it's a lot of work for very little benefit.
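
    For reference, such a hand-written search could look roughly like this (a minimal sketch, not the exact code I used; it assumes the plain SciKeras KerasClassifier NN and the CV object from the question, and scores each fold with the estimator's default score method):

        from sklearn.base import clone
        from sklearn.model_selection import ParameterGrid
        import numpy as np
        
        def adhoc_grid_search(estimator, param_grid, X, y, cv):
          results = []
          for params in ParameterGrid(param_grid):
            fold_scores = []
            for train_idx, val_idx in cv.split(X, y):
              est = clone(estimator).set_params(**params)
              X_tr, y_tr = X.iloc[train_idx], y.iloc[train_idx]
              X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
              # the validation fold goes straight to Keras' fit through SciKeras,
              # so 'val_loss' is available for EarlyStopping
              est.fit(X_tr, y_tr, validation_data=(X_val, y_val))
              fold_scores.append(est.score(X_val, y_val))
            results.append((params, np.mean(fold_scores)))
          # return the combination with the best mean score over the K folds
          return max(results, key=lambda r: r[1])
        
        # best_params, best_score = adhoc_grid_search(NN, NN_MONK1_GRID_DICT, X_train, y_train, CV)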

    Note: the callback-based approach didn't work because the parameters of the fit method are evaluated before the callback is invoked. Thus, by the time the callback runs, the fit__validation_data it sets is never read.