Tags: keras, callback, pipeline, wrapper, cross-validation

How can I use a keras callback in a sklearn pipeline?


I am trying to create a simple multi-layer perceptron (MLP) using Keras. To avoid data leakage, I am using a pipeline inside a cross-validation routine.

To do that I have to use the Keras scikit-learn wrapper. Everything works fine until I put a TensorBoard callback into the wrapper. I have read plenty of Stack Overflow answers and my code looks correct, but I get the following error:

> RuntimeError: Cannot clone object <tensorflow.python.keras.wrappers.scikit_learn.KerasClassifier object at 0x00000245DD5C2A60>, as the constructor either does not set or modifies parameter callbacks

Here is my code:

import tensorflow as tf

# Network and training parameters
EPOCHS = 100
BATCH_SIZE = 16
VERBOSE = 0
INPUT_SHAPE = (Xtr.shape[1],)
OUTPUT_SHAPE = 1  # number of outputs
N_HIDDEN = 8


def build_mlp(n_hidden, input_shape, output_shape):
    # Build a one-hidden-layer binary classifier
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Dense(units=n_hidden,
                                    input_shape=input_shape,
                                    name='dense_layer_1',
                                    activation='relu'))
    model.add(tf.keras.layers.Dense(units=output_shape,
                                    name='output_layer',
                                    activation='sigmoid'))
    model.compile(optimizer='Adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

# TensorBoard callback
import datetime
LOG_DIR = "logs/MLP_anomaly/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
CALLBACKS = [tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)]

# Create a wrapper to use sklearn pipelines
sk_model = tf.keras.wrappers.scikit_learn.KerasClassifier(build_fn=build_mlp,
                                                          epochs=EPOCHS,
                                                          batch_size=BATCH_SIZE,
                                                          callbacks=CALLBACKS,
                                                          verbose=VERBOSE,
                                                          n_hidden=N_HIDDEN,
                                                          input_shape=INPUT_SHAPE,
                                                          output_shape=OUTPUT_SHAPE)

# Use a pipeline
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

pipe = Pipeline([('scaler', MinMaxScaler()), ('mlp', sk_model)])

# Cross-validation
from sklearn.model_selection import RepeatedStratifiedKFold, cross_validate

n_splits, n_repeats = 3, 1
cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=seed)
cv_rslt = cross_validate(pipe, Xtrx, Ytr, cv=cv,
                         return_train_score=True,
                         scoring='accuracy',
                         return_estimator=True)

The full error I am getting is:

> ---------------------------------------------------------------------------
Empty                                     Traceback (most recent call last)
~\.conda\envs\PrognosticEnv\lib\site-packages\joblib\parallel.py in dispatch_one_batch(self, iterator)
    819             try:
--> 820                 tasks = self._ready_batches.get(block=False)
    821             except queue.Empty:

~\.conda\envs\PrognosticEnv\lib\queue.py in get(self, block, timeout)
    166                 if not self._qsize():
--> 167                     raise Empty
    168             elif timeout is None:

Empty: 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-12-47de7339b00e> in <module>
      2 n_splits, n_repeats = 3, 1
      3 cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=seed)
----> 4 cv_rslt = cross_validate(pipe, Xtrx, Ytr, cv=cv,
      5                          return_train_score = True,
      6                          scoring = 'accuracy',

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\model_selection\_validation.py in cross_validate(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch, return_train_score, return_estimator, error_score)
    240     parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
    241                         pre_dispatch=pre_dispatch)
--> 242     scores = parallel(
    243         delayed(_fit_and_score)(
    244             clone(estimator), X, y, scorers, train, test, verbose, None,

~\.conda\envs\PrognosticEnv\lib\site-packages\joblib\parallel.py in __call__(self, iterable)
   1039             # remaining jobs.
   1040             self._iterating = False
-> 1041             if self.dispatch_one_batch(iterator):
   1042                 self._iterating = self._original_iterator is not None
   1043 

~\.conda\envs\PrognosticEnv\lib\site-packages\joblib\parallel.py in dispatch_one_batch(self, iterator)
    829                 big_batch_size = batch_size * n_jobs
    830 
--> 831                 islice = list(itertools.islice(iterator, big_batch_size))
    832                 if len(islice) == 0:
    833                     return False

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\model_selection\_validation.py in <genexpr>(.0)
    242     scores = parallel(
    243         delayed(_fit_and_score)(
--> 244             clone(estimator), X, y, scorers, train, test, verbose, None,
    245             fit_params, return_train_score=return_train_score,
    246             return_times=True, return_estimator=return_estimator,

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\base.py in clone(estimator, safe)
     85     new_object_params = estimator.get_params(deep=False)
     86     for name, param in new_object_params.items():
---> 87         new_object_params[name] = clone(param, safe=False)
     88     new_object = klass(**new_object_params)
     89     params_set = new_object.get_params(deep=False)

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\base.py in clone(estimator, safe)
     66     # XXX: not handling dictionaries
     67     if estimator_type in (list, tuple, set, frozenset):
---> 68         return estimator_type([clone(e, safe=safe) for e in estimator])
     69     elif not hasattr(estimator, 'get_params') or isinstance(estimator, type):
     70         if not safe:

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\base.py in <listcomp>(.0)
     66     # XXX: not handling dictionaries
     67     if estimator_type in (list, tuple, set, frozenset):
---> 68         return estimator_type([clone(e, safe=safe) for e in estimator])
     69     elif not hasattr(estimator, 'get_params') or isinstance(estimator, type):
     70         if not safe:

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\base.py in clone(estimator, safe)
     66     # XXX: not handling dictionaries
     67     if estimator_type in (list, tuple, set, frozenset):
---> 68         return estimator_type([clone(e, safe=safe) for e in estimator])
     69     elif not hasattr(estimator, 'get_params') or isinstance(estimator, type):
     70         if not safe:

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\base.py in <listcomp>(.0)
     66     # XXX: not handling dictionaries
     67     if estimator_type in (list, tuple, set, frozenset):
---> 68         return estimator_type([clone(e, safe=safe) for e in estimator])
     69     elif not hasattr(estimator, 'get_params') or isinstance(estimator, type):
     70         if not safe:

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~\.conda\envs\PrognosticEnv\lib\site-packages\sklearn\base.py in clone(estimator, safe)
     94         param2 = params_set[name]
     95         if param1 is not param2:
---> 96             raise RuntimeError('Cannot clone object %s, as the constructor '
     97                                'either does not set or modifies parameter %s' %
     98                                (estimator, name))

RuntimeError: Cannot clone object <tensorflow.python.keras.wrappers.scikit_learn.KerasClassifier object at 0x00000245DD5C2A60>, as the constructor either does not set or modifies parameter callbacks
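If I read the traceback correctly, clone() rebuilds the estimator from get_params() and then compares every parameter with an identity check; the Keras wrapper's get_params() seems to return copies of its parameters, so the callback objects in the rebuilt estimator are new objects and the check fails. A minimal reproduction (this is just my reading of the traceback, no pipeline or cross-validation needed):

# Cloning the wrapper alone already raises: the copied callback objects
# are not identical ('is') to the originals, so sklearn's consistency
# check on the 'callbacks' parameter fails.
from sklearn.base import clone
clone(sk_model)  # RuntimeError: Cannot clone object ...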

I have already tried putting the callback like this:

pipe.set_params(mlp__callbacks=CALLBACKS);

I also tried passing the callbacks through the fit_params argument of the cross_validate function. Nothing works for me. Does anyone have a suggestion?
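For reference, the fit_params attempt looked roughly like this (a sketch; the mlp__ prefix routes the parameter to the pipeline step named 'mlp'):

# Pass the callbacks at fit time instead of at construction time;
# this also did not work for me.
cv_rslt = cross_validate(pipe, Xtrx, Ytr, cv=cv,
                         scoring='accuracy',
                         fit_params={'mlp__callbacks': CALLBACKS})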

Thank you very much


Solution

  • So I finally found a solution, or rather a workaround. I am writing it down here hoping it will be useful to other ML practitioners. The explanation of my problem is simple and comes down to three points:

    1. sklearn does not provide a way to plot the training history of a model. The closest thing I found to the Keras history is the loss_ attribute of MLPClassifier.
    2. tensorflow and keras do not provide cross-validation and pipeline routines to avoid data leakage (since there is usually no room for CV in deep learning).
    3. wrapping a Keras MLP with KerasClassifier and putting it in a sklearn pipeline is not useful, because it is not possible to extract the training history of the classifier from the pipeline (when calling the fit function).

    So in the end I used sklearn's validation_curve function to plot the MLP score as a function of the number of training epochs, as sketched below. To avoid data leakage I kept the pipeline and sklearn's cross-validation method.
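    A minimal sketch of the workaround (my reconstruction, reusing the pipe, Xtrx, Ytr and cv objects defined in the question, with the TensorBoard callback removed from the wrapper so that cloning succeeds; the 'mlp__epochs' name follows the pipeline step name):

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.model_selection import validation_curve

    # Refit the pipeline for each number of epochs; the train/validation
    # scores play the role of the Keras History object, but are computed
    # in a CV-safe way (the scaler is re-fit inside each fold).
    epoch_range = np.array([1, 5, 10, 25, 50, 100])
    train_scores, valid_scores = validation_curve(pipe, Xtrx, Ytr,
                                                  param_name='mlp__epochs',
                                                  param_range=epoch_range,
                                                  cv=cv,
                                                  scoring='accuracy')

    # Average over the CV folds and plot score vs. training epochs
    plt.plot(epoch_range, train_scores.mean(axis=1), label='train')
    plt.plot(epoch_range, valid_scores.mean(axis=1), label='validation')
    plt.xlabel('epochs')
    plt.ylabel('accuracy')
    plt.legend()
    plt.show()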

    [Figure: MLP training history]