I am trying to use `GridSearchCV` to tune the `epochs` hyper-parameter of my model. This is a minimal, reproducible example:
# Minimal reproduction: grid-search the `epochs` hyper-parameter of a compiled VAE.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# NOTE(review): x_train/x_test are neither given a trailing channel axis nor
# rescaled here, while the Encoder below is built for input shape (28, 28, 1)
# -- confirm whether np.expand_dims(..., -1) (and scaling) is needed.
latent_dimension = 25
encoder = Encoder(latent_dimension, (28, 28, 1))
decoder = Decoder()
vae = VAE(encoder, decoder)
vae.compile(Adam()) # Compiled here
# NOTE: GridSearchCV does not fit `vae` itself. It fits clones re-created from
# the estimator's constructor parameters (encoder, decoder, epochs). Those
# clones were never compiled, which is what raises "You must compile your
# model before training/testing" in the traceback.
param_grid = {'epochs': [10, 20, 30]}
# NOTE(review): a callable `scoring` is invoked as scorer(estimator, X, y),
# not metric(y_true, y_pred); a plain metric presumably needs make_scorer(...).
grid = GridSearchCV(vae, param_grid, scoring=mean_absolute_error, cv=2)
# NOTE(review): y_train holds digit labels, while the VAE outputs reconstructed
# images -- presumably the targets should be the inputs themselves.
grid.fit(x_train, y_train)
But unfortunately, `grid.fit(x_train, y_train)` gives:
RuntimeError: You must compile your model before training/testing. Use `model.compile(optimizer, loss)`.
But I have already compiled my model. How can I fix this problem?
This is the implementation of `VAE`, `Encoder` and `Decoder`:
import tensorflow as tf
from keras import layers
from keras.optimizers.legacy import Adam
from sklearn.base import BaseEstimator
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import GridSearchCV
from tensorflow import keras
def sample(z_mean, z_log_var):
    """Draw a latent sample with the reparameterization trick.

    Computes ``z_mean + exp(0.5 * z_log_var) * eps`` with ``eps`` drawn from a
    standard normal whose shape is the first two dimensions of ``z_mean``.

    Args:
        z_mean: batch of latent means.
        z_log_var: batch of latent log-variances, same shape as ``z_mean``.

    Returns:
        A tensor shaped like ``z_mean``.
    """
    mean_shape = tf.shape(z_mean)
    noise = tf.random.normal(shape=(mean_shape[0], mean_shape[1]))
    return z_mean + tf.exp(0.5 * z_log_var) * noise
class VAE(keras.Model, BaseEstimator):
    """Variational autoencoder trained on reconstruction + KL-divergence loss.

    Also inherits ``BaseEstimator`` so scikit-learn can read and write the
    constructor parameters (``encoder``, ``decoder``, ``epochs``) through
    ``get_params``/``set_params``. scikit-learn's ``clone`` re-creates the
    model from those parameters, so a clone is NOT compiled.

    Args:
        encoder: callable returning ``(z_mean, z_log_var, z)`` for a batch.
        decoder: callable mapping ``z`` back to a reconstruction.
        epochs: kept only as a scikit-learn hyper-parameter; ``keras.Model.fit``
            takes its epoch count from its own argument, not from this attribute.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, epochs=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.epochs = epochs
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def call(self, inputs, training=None, mask=None):
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)

    @property
    def metrics(self):
        # Trackers listed here are the ones Keras resets between epochs.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @staticmethod
    def _unpack_inputs(data):
        """Return the image batch, dropping targets/sample weights if Keras passed a tuple."""
        # fit(x, y) yields (x, y) tuples; fit(x) yields the bare batch.
        if isinstance(data, tuple):
            return data[0]
        return data

    def _compute_losses(self, data, training):
        """Run encoder/decoder on ``data``; return (total, reconstruction, kl) losses."""
        # Pass `training` explicitly: otherwise the BatchNormalization layers
        # inside encoder/decoder default to inference mode, even in train_step.
        z_mean, z_log_var, z = self.encoder(data, training=training)
        reconstruction = self.decoder(z, training=training)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

    def _update_and_report(self, total_loss, reconstruction_loss, kl_loss):
        """Update the running-mean trackers and return the logs dict for Keras."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        data = self._unpack_inputs(data)
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(data, training=True)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._update_and_report(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        data = self._unpack_inputs(data)
        total_loss, reconstruction_loss, kl_loss = self._compute_losses(data, training=False)
        return self._update_and_report(total_loss, reconstruction_loss, kl_loss)
class Encoder(keras.Model):
    """Convolutional encoder producing latent mean, log-variance and a sample.

    Args:
        latent_dimension: size of the latent vector.
        input_shape: per-example image shape handed to the first block's
            ``Input`` layer.
    """

    def __init__(self, latent_dimension, input_shape):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dimension
        # Three Conv2D(stride 2, "same") + BatchNormalization blocks; only the
        # first one declares the input shape.
        self.conv_block1 = self._downsample_block(64, input_shape)
        self.conv_block2 = self._downsample_block(128)
        self.conv_block3 = self._downsample_block(256)
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(units=100, activation="relu")
        self.z_mean = layers.Dense(latent_dimension, name="z_mean")
        self.z_log_var = layers.Dense(latent_dimension, name="z_log_var")
        self.sampling = sample

    @staticmethod
    def _downsample_block(filters, input_shape=None):
        """Build a Sequential of [Input?] -> Conv2D(stride 2) -> BatchNormalization."""
        stack = [] if input_shape is None else [layers.Input(shape=input_shape)]
        stack.append(
            layers.Conv2D(filters=filters, kernel_size=3, activation="relu", strides=2, padding="same")
        )
        stack.append(layers.BatchNormalization())
        return keras.Sequential(stack)

    def call(self, inputs, training=None, mask=None):
        hidden = inputs
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3, self.flatten, self.dense):
            hidden = stage(hidden)
        mean = self.z_mean(hidden)
        log_var = self.z_log_var(hidden)
        return mean, log_var, self.sampling(mean, log_var)
class Decoder(keras.Model):
    """Dense + transposed-convolution decoder producing 28x28x1 images.

    Spatial path (Conv2DTranspose output = (in - 1) * stride + kernel for
    "valid", in * stride for "same"):
    Reshape 3x3 -> deconv1 6x6 -> deconv2 6x6 -> deconv3 13x13
    -> deconv5 27x27 -> deconv6 28x28.
    The previous 4x4 start with an extra stride-1 "valid" block produced
    40x40, which could not be compared with 28x28 inputs in the loss.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.dense1 = self._with_batch_norm(layers.Dense(units=4096, activation="relu"))
        self.dense2 = self._with_batch_norm(layers.Dense(units=1024, activation="relu"))
        # 3 * 3 * 256 = 2304 units so the Reshape below is exact.
        self.dense3 = self._with_batch_norm(layers.Dense(units=2304, activation="relu"))
        self.reshape = layers.Reshape((3, 3, 256))
        self.deconv1 = self._with_batch_norm(
            layers.Conv2DTranspose(filters=256, kernel_size=3, activation="relu", strides=2, padding="same")
        )
        self.deconv2 = self._with_batch_norm(
            layers.Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=1, padding="same")
        )
        self.deconv3 = self._with_batch_norm(
            layers.Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=2, padding="valid")
        )
        self.deconv5 = self._with_batch_norm(
            layers.Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=2, padding="valid")
        )
        self.deconv6 = layers.Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid")

    @staticmethod
    def _with_batch_norm(layer):
        """Wrap ``layer`` followed by BatchNormalization in a Sequential."""
        return keras.Sequential([layer, layers.BatchNormalization()])

    def call(self, inputs, training=None, mask=None):
        x = self.dense1(inputs)
        x = self.dense2(x)
        x = self.dense3(x)
        x = self.reshape(x)
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv5(x)
        return self.deconv6(x)
And this is the full traceback of the error:
Traceback (most recent call last):
File "/Users/alex/PycharmProjects/VAE-EEG-XAI/del.py", line 195, in <module>
grid.fit(x_train, y_train)
File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 874, in fit
self._run_search(evaluate_candidates)
File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 1388, in _run_search
evaluate_candidates(ParameterGrid(self.param_grid))
File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 851, in evaluate_candidates
_warn_or_raise_about_fit_failures(out, self.error_score)
File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 367, in _warn_or_raise_about_fit_failures
raise ValueError(all_fits_failed_message)
ValueError:
All the 6 fits failed.
It is very likely that your model is misconfigured.
You can try to debug the error by setting error_score='raise'.
Below are more details about the failures:
--------------------------------------------------------------------------------
6 fits failed with the following error:
Traceback (most recent call last):
File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/keras/src/engine/training.py", line 3875, in _assert_compile_was_called
raise RuntimeError(
RuntimeError: You must compile your model before training/testing. Use `model.compile(optimizer, loss)`.
I think I found the answer after digging through the source code.
Inside `GridSearchCV`, sklearn clones the model, re-instantiating it from its constructor parameters. That's a problem for TensorFlow models, because at that point you have no way to compile the cloned model again. So I think you cannot use sklearn's `GridSearchCV` natively with TF models.
You can try SciKeras (https://adriangb.com/scikeras/stable/) and wrap your model in a `KerasClassifier` wrapper.
Or you can write your own grid-search loop. I did, and it is not complicated if you only want to test easy-to-reach parameters. If you want to test different things, you can use sklearn's `ParameterGrid` to create the parameter dictionary for every run.
Edit: The following code apparently works too. It will at least start to fit, but it still crashes at the scoring step. I think you need to change the scorer (I also used a wrapper there). You also need to change the labels: the prediction of the VAE is not a label but a whole image, so there is a shape mismatch (`y_prediction_shape=(100, 28, 28, 1)`, `y_true_shape=(100,)`).
I changed the `Encoder` and `Decoder` to `Layer` classes and removed the `Sequential` sub-blocks. I split the data into `data, labels` at the start of the train/test steps.

I also fixed your `Decoder` output shape (removed `deconv4`, and changed the `Reshape` and the `Dense` before it), because it did not match the input images. You might want to change that again.

I added an ugly wrapper class that acts as the VAE but compiles it in `__init__()`, so it can be used even after cloning. (Remember that the cloned model starts with randomly initialized weights again, so it could perform differently with the same parameters on the same data.) I'm still not sure whether the wrapper class works 100%.
import numpy as np
import tensorflow as tf
from keras import layers
from keras.optimizers.legacy import Adam
from sklearn.base import BaseEstimator
from sklearn.metrics import mean_absolute_error, make_scorer
from sklearn.model_selection import GridSearchCV
from tensorflow import keras
def flat_mae(x, y):
    """Mean absolute error between two arrays after flattening each to 1-D.

    Lets ``make_scorer`` compare whole reconstructed images element-wise.
    Mean absolute error is symmetric, so which argument is the prediction
    does not affect the result.
    """
    flat_second = y.flatten()
    flat_first = x.flatten()
    return mean_absolute_error(flat_second, flat_first)
def sample(z_mean, z_log_var):
    """Draw a latent sample via the reparameterization trick.

    Returns ``z_mean + exp(0.5 * z_log_var) * eps`` where ``eps`` is standard
    normal noise shaped by the first two dimensions of ``z_mean``.
    """
    mean_shape = tf.shape(z_mean)
    noise = tf.random.normal(shape=(mean_shape[0], mean_shape[1]))
    return z_mean + tf.exp(0.5 * z_log_var) * noise
class VAEWrapper:
    """scikit-learn-facing wrapper that builds and compiles a fresh ``VAE``.

    scikit-learn's ``clone`` re-creates an estimator by calling its
    constructor with the values returned by ``get_params``. Compiling inside
    ``__init__`` therefore leaves every clone ready to train.

    Args:
        **kwargs: forwarded unchanged to ``VAE`` (e.g. ``encoder``,
            ``decoder``, ``epochs``).
    """

    def __init__(self, **kwargs):
        self.vae = VAE(**kwargs)
        self.vae.compile(Adam())

    def fit(self, x, y, **kwargs):
        """Train the wrapped VAE for its ``epochs`` hyper-parameter.

        ``keras.Model.fit`` does not read the ``epochs`` attribute that
        ``set_params`` writes, so it is forwarded explicitly unless the caller
        passed ``epochs`` directly.

        Returns:
            self, following the scikit-learn estimator convention.
        """
        if self.vae.epochs is not None:
            kwargs.setdefault("epochs", self.vae.epochs)
        self.vae.fit(x, y, **kwargs)
        return self

    def predict(self, x, **kwargs):
        """Return the wrapped VAE's predictions (used by ``make_scorer`` scorers)."""
        return self.vae.predict(x, **kwargs)

    def get_config(self):
        return self.vae.get_config()

    def get_params(self, deep=True):
        return self.vae.get_params(deep)

    def set_params(self, **params):
        """Set hyper-parameters on the wrapped VAE and return the wrapper itself.

        scikit-learn uses the return value of ``set_params`` as the estimator.
        Returning the inner VAE would silently replace this wrapper.
        """
        self.vae.set_params(**params)
        return self
class VAE(keras.Model, BaseEstimator):
    """Variational autoencoder trained on reconstruction + KL-divergence loss.

    Also inherits ``BaseEstimator`` so scikit-learn can read and write the
    constructor parameters (``encoder``, ``decoder``, ``epochs``) through
    ``get_params``/``set_params``.

    Args:
        encoder: callable returning ``(z_mean, z_log_var, z)`` for a batch.
        decoder: callable mapping ``z`` back to a reconstruction.
        epochs: kept only as a scikit-learn hyper-parameter; ``keras.Model.fit``
            takes its epoch count from its own argument, not from this attribute.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, epochs=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.epochs = epochs
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def call(self, inputs, training=None, mask=None):
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)

    @property
    def metrics(self):
        # Trackers listed here are the ones Keras resets between epochs.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @staticmethod
    def _unpack_inputs(data):
        """Return the image batch, dropping targets/sample weights if Keras passed a tuple."""
        # fit(x, y) yields (x, y) tuples; fit(x) yields the bare batch. The
        # former `data, labels = data` crashed on the latter and on 3-tuples.
        if isinstance(data, tuple):
            return data[0]
        return data

    def _compute_losses(self, data, training):
        """Run encoder/decoder on ``data``; return (total, reconstruction, kl) losses."""
        # Pass `training` explicitly: otherwise the BatchNormalization layers
        # inside encoder/decoder default to inference mode, even in train_step.
        z_mean, z_log_var, z = self.encoder(data, training=training)
        reconstruction = self.decoder(z, training=training)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

    def _update_and_report(self, total_loss, reconstruction_loss, kl_loss):
        """Update the running-mean trackers and return the logs dict for Keras."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        data = self._unpack_inputs(data)
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(data, training=True)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._update_and_report(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        data = self._unpack_inputs(data)
        total_loss, reconstruction_loss, kl_loss = self._compute_losses(data, training=False)
        return self._update_and_report(total_loss, reconstruction_loss, kl_loss)
@keras.saving.register_keras_serializable()
class Encoder(keras.layers.Layer):
    """Convolutional encoder layer producing latent mean, log-variance and a sample.

    Args:
        latent_dimension: size of the latent vector.
    """

    def __init__(self, latent_dimension):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dimension
        # conv1/bn1, conv2/bn2, conv3/bn3: stride-2 "same" convolutions with
        # 64, 128 and 256 filters, each followed by batch normalization.
        for idx, filters in enumerate((64, 128, 256), start=1):
            setattr(
                self,
                f"conv{idx}",
                layers.Conv2D(filters=filters, kernel_size=3, activation="relu", strides=2, padding="same"),
            )
            setattr(self, f"bn{idx}", layers.BatchNormalization())
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(units=100, activation="relu")
        self.z_mean = layers.Dense(latent_dimension, name="z_mean")
        self.z_log_var = layers.Dense(latent_dimension, name="z_log_var")
        self.sampling = sample

    def call(self, inputs, training=None, mask=None):
        hidden = inputs
        feature_stages = (
            self.conv1, self.bn1,
            self.conv2, self.bn2,
            self.conv3, self.bn3,
            self.flatten,
            self.dense,
        )
        for stage in feature_stages:
            hidden = stage(hidden)
        mean = self.z_mean(hidden)
        log_var = self.z_log_var(hidden)
        return mean, log_var, self.sampling(mean, log_var)
@keras.saving.register_keras_serializable()
class Decoder(keras.layers.Layer):
    """Dense + transposed-convolution decoder layer.

    Spatial path (Conv2DTranspose output = (in - 1) * stride + kernel for
    "valid", in * stride for "same"):
    Reshape 3x3 -> deconv1 6x6 -> deconv2 6x6 -> deconv3 13x13
    -> deconv5 27x27 -> deconv6 28x28.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.dense1 = layers.Dense(units=4096, activation="relu")
        self.bn1 = layers.BatchNormalization()
        self.dense2 = layers.Dense(units=1024, activation="relu")
        self.bn2 = layers.BatchNormalization()
        # 2304 = 3 * 3 * 256, matching the Reshape target.
        self.dense3 = layers.Dense(units=2304, activation="relu")
        self.bn3 = layers.BatchNormalization()
        self.reshape = layers.Reshape((3, 3, 256))
        self.deconv1 = layers.Conv2DTranspose(filters=256, kernel_size=3, activation="relu", strides=2, padding="same")
        self.bn4 = layers.BatchNormalization()
        self.deconv2 = layers.Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=1, padding="same")
        self.bn5 = layers.BatchNormalization()
        self.deconv3 = layers.Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=2, padding="valid")
        self.bn6 = layers.BatchNormalization()
        self.deconv5 = layers.Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=2, padding="valid")
        self.bn8 = layers.BatchNormalization()
        self.deconv6 = layers.Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid")

    def call(self, inputs, training=None, mask=None):
        pipeline = (
            self.dense1, self.bn1,
            self.dense2, self.bn2,
            self.dense3, self.bn3,
            self.reshape,
            self.deconv1, self.bn4,
            self.deconv2, self.bn5,
            self.deconv3, self.bn6,
            self.deconv5, self.bn8,
            self.deconv6,
        )
        out = inputs
        for stage in pipeline:
            out = stage(out)
        return out
# Load MNIST, add the trailing channel axis the Conv2D encoder expects, and
# scale pixels from 0..255 to [0, 1]: the VAE's binary-crossentropy
# reconstruction loss needs targets in that range.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype('float32') / 255.0
x_test = np.expand_dims(x_test, -1).astype('float32') / 255.0
latent_dimension = 25
param_grid = {'epochs': [10, 20, 30]}
# The VAE predicts whole images, so score reconstructions element-wise.
mae_scorer = make_scorer(flat_mae, greater_is_better=False)
grid = GridSearchCV(
    VAEWrapper(encoder=Encoder(latent_dimension), decoder=Decoder()),
    param_grid,
    scoring=mae_scorer,
    cv=2,
    refit=False,
)
# Autoencoder: the inputs double as the targets.
grid.fit(x_train, x_train)
# refit=False, so train a fresh model with the best epoch count found.
vae = VAE(Encoder(latent_dimension), Decoder())
vae.compile(Adam())
vae.fit(x_train, x_train, batch_size=32, epochs=grid.best_params_["epochs"])
preds = vae.predict(x_test, batch_size=32)
# evaluate() reports the tracked losses (total / reconstruction / KL), not accuracy.
acc = vae.evaluate(x_test, x_test, batch_size=32)