Search code examples
pythontensorflowmachine-learningkerasdeep-learning

GridSearchCV: You must compile your model before training/testing. But my model is already compiled


I am trying to use GridSearchCV for tuning the hyper-parameter epochs of my model. This is the minimal, reproducible example:

# Load MNIST as ((train images, train labels), (test images, test labels)).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Assemble the VAE from its encoder and decoder and compile it with Adam.
latent_dimension = 25
encoder = Encoder(latent_dimension, (28, 28, 1))
decoder = Decoder()
vae = VAE(encoder, decoder)
vae.compile(Adam())  # Compiled here

# Grid-search the `epochs` parameter with 2-fold cross-validation.
# NOTE(review): the compiled `vae` instance is handed to GridSearchCV as the
# estimator; the fit call below is the one that raises the RuntimeError.
param_grid = {'epochs': [10, 20, 30]}
grid = GridSearchCV(vae, param_grid, scoring=mean_absolute_error, cv=2)
grid.fit(x_train, y_train)

But unfortunately grid.fit(x_train, y_train) gives:

RuntimeError: You must compile your model before training/testing. Use model.compile(optimizer, loss).

But I already compiled my model. How can I fix the problem?

This is VAE, Encoder and Decoder implementation:

import tensorflow as tf
from keras import layers
from keras.optimizers.legacy import Adam
from sklearn.base import BaseEstimator
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import GridSearchCV
from tensorflow import keras

def sample(z_mean, z_log_var):
    """Draw ``z = z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, 1)``.

    The noise tensor is shaped from the first two dimensions of ``z_mean``.
    """
    noise_shape = (tf.shape(z_mean)[0], tf.shape(z_mean)[1])
    noise = tf.random.normal(shape=noise_shape)
    return z_mean + tf.exp(0.5 * z_log_var) * noise

class VAE(keras.Model, BaseEstimator):
    """Variational autoencoder built from an encoder/decoder pair.

    The training objective is the binary cross-entropy reconstruction term
    (summed over axes 1 and 2, averaged over the batch) plus the KL term
    computed from ``z_mean`` and ``z_log_var``.

    Args:
        encoder: Callable returning ``(z_mean, z_log_var, z)`` for a batch.
        decoder: Callable mapping ``z`` to a reconstruction of the input.
        epochs: Stored as an attribute only; no method of this class reads it.
        **kwargs: Forwarded to ``keras.Model.__init__``.
    """

    def __init__(self, encoder, decoder, epochs=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.epochs = epochs
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def call(self, inputs, training=None, mask=None):
        """Encode ``inputs``, then decode the sampled latent ``z``."""
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)

    @property
    def metrics(self):
        """The three running-mean loss trackers updated by the step methods."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @staticmethod
    def _inputs_only(data):
        """Return the input batch, dropping targets if ``data`` is a tuple."""
        # fit(x, y) hands the step an (x, y[, sample_weight]) tuple; the VAE
        # reconstructs x and never uses the targets.
        if isinstance(data, tuple):
            return data[0]
        return data

    def _compute_losses(self, data, training):
        """Run the forward pass and return ``(total, reconstruction, kl)``."""
        z_mean, z_log_var, z = self.encoder(data, training=training)
        reconstruction = self.decoder(z, training=training)

        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

    def _track(self, total_loss, reconstruction_loss, kl_loss):
        """Update the loss trackers and return their current results."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        """Run one optimisation step and return the running mean losses."""
        data = self._inputs_only(data)
        with tf.GradientTape() as tape:
            # training=True so that sub-layers run in training mode here.
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(
                data, training=True
            )

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._track(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        """Evaluate one batch (no weight update) and return the mean losses."""
        data = self._inputs_only(data)
        total_loss, reconstruction_loss, kl_loss = self._compute_losses(
            data, training=False
        )
        return self._track(total_loss, reconstruction_loss, kl_loss)


class Encoder(keras.Model):
    """Convolutional encoder returning ``(z_mean, z_log_var, z)``.

    Three stride-2 Conv2D + BatchNormalization stages (64, 128, 256 filters)
    are followed by Flatten and a 100-unit dense layer; two dense heads of
    width ``latent_dimension`` then produce ``z_mean`` and ``z_log_var``.
    """

    def __init__(self, latent_dimension, input_shape):
        super().__init__()

        def conv_stage(filters, *leading):
            # `leading` lets the first stage prepend its Input layer.
            return keras.Sequential([
                *leading,
                layers.Conv2D(filters=filters, kernel_size=3, activation="relu", strides=2, padding="same"),
                layers.BatchNormalization(),
            ])

        self.latent_dim = latent_dimension
        self.conv_block1 = conv_stage(64, layers.Input(shape=input_shape))
        self.conv_block2 = conv_stage(128)
        self.conv_block3 = conv_stage(256)
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(units=100, activation="relu")
        self.z_mean = layers.Dense(latent_dimension, name="z_mean")
        self.z_log_var = layers.Dense(latent_dimension, name="z_log_var")
        self.sampling = sample

    def call(self, inputs, training=None, mask=None):
        """Map a batch of inputs to the latent mean, log-variance and a sample."""
        features = inputs
        for stage in (
            self.conv_block1,
            self.conv_block2,
            self.conv_block3,
            self.flatten,
            self.dense,
        ):
            features = stage(features)
        mean = self.z_mean(features)
        log_var = self.z_log_var(features)
        return mean, log_var, self.sampling(mean, log_var)


class Decoder(keras.Model):
    """Decoder: three dense stages, a reshape, then transposed convolutions.

    NOTE(review): by the Conv2DTranspose size rules this stack appears to emit
    40x40x1 (4 -> 8 -> 8 -> 17 -> 19 -> 39 -> 40) rather than 28x28x1 -
    confirm against the intended input size.
    """

    def __init__(self):
        super().__init__()

        def dense_stage(units):
            # Dense + BatchNormalization pair.
            return keras.Sequential([
                layers.Dense(units=units, activation="relu"),
                layers.BatchNormalization(),
            ])

        def deconv_stage(filters, strides, padding):
            # 3x3 Conv2DTranspose + BatchNormalization pair.
            return keras.Sequential([
                layers.Conv2DTranspose(filters=filters, kernel_size=3, activation="relu", strides=strides, padding=padding),
                layers.BatchNormalization(),
            ])

        self.dense1 = dense_stage(4096)
        self.dense2 = dense_stage(1024)
        self.dense3 = dense_stage(4096)
        self.reshape = layers.Reshape((4, 4, 256))
        self.deconv1 = deconv_stage(256, 2, "same")
        self.deconv2 = deconv_stage(128, 1, "same")
        self.deconv3 = deconv_stage(128, 2, "valid")
        self.deconv4 = deconv_stage(64, 1, "valid")
        self.deconv5 = deconv_stage(64, 2, "valid")
        # Final single-channel layer with sigmoid activation.
        self.deconv6 = layers.Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid")

    def call(self, inputs, training=None, mask=None):
        """Apply every stage in order and return the last layer's output."""
        hidden = inputs
        for stage in (
            self.dense1,
            self.dense2,
            self.dense3,
            self.reshape,
            self.deconv1,
            self.deconv2,
            self.deconv3,
            self.deconv4,
            self.deconv5,
        ):
            hidden = stage(hidden)
        return self.deconv6(hidden)

And this is the full traceback of the error:

Traceback (most recent call last):
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/del.py", line 195, in <module>
    grid.fit(x_train, y_train)
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 874, in fit
    self._run_search(evaluate_candidates)
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 1388, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 851, in evaluate_candidates
    _warn_or_raise_about_fit_failures(out, self.error_score)
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 367, in _warn_or_raise_about_fit_failures
    raise ValueError(all_fits_failed_message)
ValueError: 
All the 6 fits failed.
It is very likely that your model is misconfigured.
You can try to debug the error by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
6 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/keras/src/engine/training.py", line 3875, in _assert_compile_was_called
    raise RuntimeError(
RuntimeError: You must compile your model before training/testing. Use `model.compile(optimizer, loss)`. 

Solution

  • I think I found the answer after digging through the source code. In this line from sklearn the model gets cloned (copied and re-instantiated), and that's a problem for TensorFlow models, because at that point you have no way to compile the cloned model again. So I think you cannot use sklearn's GridSearchCV natively with TF models.
    You can try to use https://adriangb.com/scikeras/stable/, and wrap your model in a KerasClassifier wrapper.
    Or you can write your own grid-search loop. I did that, and it is not that complicated if you only want to test easy-to-reach parameters. You can use sklearn's ParameterGrid to create the parameter dictionaries for every run if you want to test different things.

    Edit: The following code works too, apparently. It will at least start to fit, but it still crashes at the scoring step. I think you need to change the scorer (I also used a wrapper there) and change the labels (the prediction of the VAE is not a label but a whole image again, so you have a mismatch: y_prediction_shape=(100, 28, 28, 1), y_true_shape=(100,)).
    I changed the Encoder and Decoder to Layer subclasses and removed the Sequential sub-blocks. I split the data into data, labels at the start of the train/test steps. I also fixed your Decoder output shape (removed deconv4, changed the Reshape and the Dense before it); it did not match the input images. You might want to change that again. I added an ugly wrapper class that acts as the VAE but compiles it in __init__(), so it can be used even after cloning. (Remember that the cloned model starts with randomly initialized weights again, so it could perform differently with the same parameters on the same data.) I'm still not sure if the wrapper class is working 100%.

    import numpy as np
    import tensorflow as tf
    from keras import layers
    from keras.optimizers.legacy import Adam
    from sklearn.base import BaseEstimator
    from sklearn.metrics import mean_absolute_error, make_scorer
    from sklearn.model_selection import GridSearchCV
    from tensorflow import keras
    
    def flat_mae(x, y):
        """Return the mean absolute error of ``x`` and ``y`` after flattening both."""
        first = y.flatten()
        second = x.flatten()
        return mean_absolute_error(first, second)
    def sample(z_mean, z_log_var):
        """Draw ``z = z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, 1)``.

        The noise tensor is shaped from the first two dimensions of ``z_mean``.
        """
        noise_shape = (tf.shape(z_mean)[0], tf.shape(z_mean)[1])
        noise = tf.random.normal(shape=noise_shape)
        return z_mean + tf.exp(0.5 * z_log_var) * noise
    
    
    class VAEWrapper:
        """Estimator-style wrapper that builds and compiles a VAE in ``__init__``.

        Compiling in the constructor means an instance re-created from its
        parameters is compiled as well. Parameter access is delegated to the
        wrapped ``VAE``.

        Args:
            **kwargs: Forwarded to ``VAE(...)`` (e.g. ``encoder``, ``decoder``,
                ``epochs``).
        """

        def __init__(self, **kwargs):
            self.vae = VAE(**kwargs)
            self.vae.compile(Adam())

        def fit(self, x, y=None, **kwargs):
            """Train the wrapped model and return this wrapper.

            If the wrapped model's ``epochs`` is set and the caller did not pass
            ``epochs`` explicitly, it is forwarded so a searched value is used.
            """
            epochs = self.vae.epochs
            if epochs is not None:
                kwargs.setdefault("epochs", epochs)
            self.vae.fit(x, y, **kwargs)
            return self

        def predict(self, x, **kwargs):
            """Return the wrapped model's predictions for ``x``."""
            return self.vae.predict(x, **kwargs)

        def get_config(self):
            """Return the wrapped model's config."""
            return self.vae.get_config()

        def get_params(self, deep=True):
            """Return the wrapped model's parameters."""
            return self.vae.get_params(deep)

        def set_params(self, **params):
            """Set parameters on the wrapped model and return this wrapper."""
            # Return the wrapper, not the inner VAE, so callers that rebind the
            # result keep going through this class's fit/predict.
            self.vae.set_params(**params)
            return self
    
    
    class VAE(keras.Model, BaseEstimator):
        """Variational autoencoder built from an encoder/decoder pair.

        The training objective is the binary cross-entropy reconstruction term
        (summed over axes 1 and 2, averaged over the batch) plus the KL term
        computed from ``z_mean`` and ``z_log_var``.

        Args:
            encoder: Callable returning ``(z_mean, z_log_var, z)`` for a batch.
            decoder: Callable mapping ``z`` to a reconstruction of the input.
            epochs: Stored as an attribute only; no method of this class reads it.
            **kwargs: Forwarded to ``keras.Model.__init__``.
        """

        def __init__(self, encoder, decoder, epochs=None, **kwargs):
            super().__init__(**kwargs)
            self.encoder = encoder
            self.decoder = decoder
            self.epochs = epochs
            self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
            self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
            self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

        def call(self, inputs, training=None, mask=None):
            """Encode ``inputs``, then decode the sampled latent ``z``."""
            _, _, z = self.encoder(inputs, training=training)
            return self.decoder(z, training=training)

        @property
        def metrics(self):
            """The three running-mean loss trackers updated by the step methods."""
            return [
                self.total_loss_tracker,
                self.reconstruction_loss_tracker,
                self.kl_loss_tracker,
            ]

        @staticmethod
        def _inputs_only(data):
            """Return the input batch, dropping targets if ``data`` is a tuple."""
            # Accept both fit(x, y) and fit(x): only unpack when a tuple arrives,
            # since the VAE reconstructs x and never uses the targets.
            if isinstance(data, tuple):
                return data[0]
            return data

        def _compute_losses(self, data, training):
            """Run the forward pass and return ``(total, reconstruction, kl)``."""
            z_mean, z_log_var, z = self.encoder(data, training=training)
            reconstruction = self.decoder(z, training=training)

            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

        def _track(self, total_loss, reconstruction_loss, kl_loss):
            """Update the loss trackers and return their current results."""
            self.total_loss_tracker.update_state(total_loss)
            self.reconstruction_loss_tracker.update_state(reconstruction_loss)
            self.kl_loss_tracker.update_state(kl_loss)
            return {
                "loss": self.total_loss_tracker.result(),
                "reconstruction_loss": self.reconstruction_loss_tracker.result(),
                "kl_loss": self.kl_loss_tracker.result(),
            }

        def train_step(self, data):
            """Run one optimisation step and return the running mean losses."""
            data = self._inputs_only(data)
            with tf.GradientTape() as tape:
                # training=True so that sub-layers run in training mode here.
                total_loss, reconstruction_loss, kl_loss = self._compute_losses(
                    data, training=True
                )

            grads = tape.gradient(total_loss, self.trainable_weights)
            self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
            return self._track(total_loss, reconstruction_loss, kl_loss)

        def test_step(self, data):
            """Evaluate one batch (no weight update) and return the mean losses."""
            data = self._inputs_only(data)
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(
                data, training=False
            )
            return self._track(total_loss, reconstruction_loss, kl_loss)
    
    @keras.saving.register_keras_serializable()
    class Encoder(keras.layers.Layer):
        """Convolutional encoder layer returning ``(z_mean, z_log_var, z)``.

        Three stride-2 Conv2D layers (64, 128, 256 filters), each followed by
        BatchNormalization, feed Flatten and a 100-unit dense layer; two dense
        heads of width ``latent_dimension`` produce ``z_mean`` and ``z_log_var``.
        """

        def __init__(self, latent_dimension):
            super().__init__()
            self.latent_dim = latent_dimension

            # Settings shared by all three convolutions.
            conv_kwargs = dict(kernel_size=3, activation="relu", strides=2, padding="same")

            self.conv1 = layers.Conv2D(filters=64, **conv_kwargs)
            self.bn1 = layers.BatchNormalization()

            self.conv2 = layers.Conv2D(filters=128, **conv_kwargs)
            self.bn2 = layers.BatchNormalization()

            self.conv3 = layers.Conv2D(filters=256, **conv_kwargs)
            self.bn3 = layers.BatchNormalization()

            self.flatten = layers.Flatten()
            self.dense = layers.Dense(units=100, activation="relu")
            self.z_mean = layers.Dense(latent_dimension, name="z_mean")
            self.z_log_var = layers.Dense(latent_dimension, name="z_log_var")
            self.sampling = sample

        def call(self, inputs, training=None, mask=None):
            """Map a batch of inputs to the latent mean, log-variance and a sample."""
            features = inputs
            for conv, norm in (
                (self.conv1, self.bn1),
                (self.conv2, self.bn2),
                (self.conv3, self.bn3),
            ):
                features = norm(conv(features))
            features = self.dense(self.flatten(features))
            mean = self.z_mean(features)
            log_var = self.z_log_var(features)
            return mean, log_var, self.sampling(mean, log_var)
    
    
    @keras.saving.register_keras_serializable()
    class Decoder(keras.layers.Layer):
        """Decoder layer: dense stages, a reshape, then transposed convolutions.

        By the Conv2DTranspose size rules the spatial size goes
        3 -> 6 -> 6 -> 13 -> 27 -> 28, giving a single-channel 28x28 output.
        """

        def __init__(self):
            super().__init__()
            self.dense1 = layers.Dense(units=4096, activation="relu")
            self.bn1 = layers.BatchNormalization()

            self.dense2 = layers.Dense(units=1024, activation="relu")
            self.bn2 = layers.BatchNormalization()

            # 2304 = 3 * 3 * 256, the element count the reshape below expects.
            self.dense3 = layers.Dense(units=2304, activation="relu")
            self.bn3 = layers.BatchNormalization()

            self.reshape = layers.Reshape((3, 3, 256))

            # Settings shared by the hidden transposed convolutions.
            deconv_kwargs = dict(kernel_size=3, activation="relu")

            self.deconv1 = layers.Conv2DTranspose(filters=256, strides=2, padding="same", **deconv_kwargs)
            self.bn4 = layers.BatchNormalization()

            self.deconv2 = layers.Conv2DTranspose(filters=128, strides=1, padding="same", **deconv_kwargs)
            self.bn5 = layers.BatchNormalization()

            self.deconv3 = layers.Conv2DTranspose(filters=128, strides=2, padding="valid", **deconv_kwargs)
            self.bn6 = layers.BatchNormalization()

            # There is intentionally no deconv4/bn7 stage (a 64-filter, stride-1,
            # "valid" layer was dropped); the attribute numbering keeps the gap.
            self.deconv5 = layers.Conv2DTranspose(filters=64, strides=2, padding="valid", **deconv_kwargs)
            self.bn8 = layers.BatchNormalization()
            self.deconv6 = layers.Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid")

        def call(self, inputs, training=None, mask=None):
            """Apply every stage in order and return the last layer's output."""
            hidden = inputs
            for layer, norm in (
                (self.dense1, self.bn1),
                (self.dense2, self.bn2),
                (self.dense3, self.bn3),
            ):
                hidden = norm(layer(hidden))
            hidden = self.reshape(hidden)
            for layer, norm in (
                (self.deconv1, self.bn4),
                (self.deconv2, self.bn5),
                (self.deconv3, self.bn6),
                (self.deconv5, self.bn8),
            ):
                hidden = norm(layer(hidden))
            return self.deconv6(hidden)
    
    
    # Load MNIST, add a trailing channel axis and scale pixels to [0, 1]:
    # the loss is binary cross-entropy against a sigmoid output, which expects
    # targets in that range.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = np.expand_dims(x_train, -1).astype('float32') / 255.0
    x_test = np.expand_dims(x_test, -1).astype('float32') / 255.0

    latent_dimension = 25
    param_grid = {'epochs': [10, 20, 30]}

    # Score by flattened reconstruction MAE; lower is better, hence greater_is_better=False.
    mae_scorer = make_scorer(flat_mae, greater_is_better=False)
    grid = GridSearchCV(VAEWrapper(encoder=Encoder(latent_dimension), decoder=Decoder()), param_grid, scoring=mae_scorer, cv=2, refit=False)
    # The images are their own targets (reconstruction task).
    grid.fit(x_train, x_train)

    # refit=False above, so train a fresh model with the best epoch count.
    vae = VAE(Encoder(latent_dimension), Decoder())
    vae.compile(Adam())
    vae.fit(x_train, x_train, batch_size=32, epochs=grid.best_params_["epochs"])
    preds = vae.predict(x_test, batch_size=32)
    # NOTE: despite the name, this holds whatever evaluate() returns for this model, not an accuracy.
    acc = vae.evaluate(x_test, x_test, batch_size=32)