I aim to save and then load the weights of my model using save_weights
and load_weights
functions.
In order to show you a minimal reproducible example, these are the dependencies you can use in my whole example:
import numpy as np
import tensorflow as tf
from keras.initializers import he_uniform
from keras.layers import Conv2DTranspose, BatchNormalization, Reshape, Dense, Conv2D, Flatten
from keras.optimizers.legacy import Adam
from keras.src.datasets import mnist
from skimage.transform import resize
from sklearn.base import BaseEstimator
from tensorflow import keras
This is my model, a (variational) autoencoder:
class VAE(keras.Model, BaseEstimator):
def __init__(self, encoder, decoder, epochs=None, l_rate=None, batch_size=None, patience=None, **kwargs):
super().__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
self.epochs = epochs
self.l_rate = l_rate
self.batch_size = batch_size
self.patience = patience
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
def call(self, inputs, training=None, mask=None):
_, _, z = self.encoder(inputs)
outputs = self.decoder(z)
return outputs
@property
def metrics(self):
return [
self.total_loss_tracker,
self.reconstruction_loss_tracker,
self.kl_loss_tracker,
]
def train_step(self, data):
data, labels = data
with tf.GradientTape() as tape:
# Forward pass
z_mean, z_log_var, z = self.encoder(data)
reconstruction = self.decoder(z)
# Compute losses
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
)
)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
# Compute gradient
grads = tape.gradient(total_loss, self.trainable_weights)
# Update weights
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
# Update metrics
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
def test_step(self, data):
data, labels = data
# Forward pass
z_mean, z_log_var, z = self.encoder(data)
reconstruction = self.decoder(z)
# Compute losses
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
)
)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
# Update metrics
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
This is the Encoder:
@keras.saving.register_keras_serializable()
class Encoder(keras.layers.Layer):
def __init__(self, latent_dimension):
super(Encoder, self).__init__()
self.latent_dim = latent_dimension
seed = 42
self.conv1 = Conv2D(filters=64, kernel_size=3, activation="relu", strides=2, padding="same",
kernel_initializer=he_uniform(seed))
self.bn1 = BatchNormalization()
self.conv2 = Conv2D(filters=128, kernel_size=3, activation="relu", strides=2, padding="same",
kernel_initializer=he_uniform(seed))
self.bn2 = BatchNormalization()
self.conv3 = Conv2D(filters=256, kernel_size=3, activation="relu", strides=2, padding="same",
kernel_initializer=he_uniform(seed))
self.bn3 = BatchNormalization()
self.flatten = Flatten()
self.dense = Dense(units=100, activation="relu")
self.z_mean = Dense(latent_dimension, name="z_mean")
self.z_log_var = Dense(latent_dimension, name="z_log_var")
self.sampling = sample
def call(self, inputs, training=None, mask=None):
x = self.conv1(inputs)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.flatten(x)
x = self.dense(x)
z_mean = self.z_mean(x)
z_log_var = self.z_log_var(x)
z = self.sampling(z_mean, z_log_var)
return z_mean, z_log_var, z
Where sample
function is defined below:
def sample(z_mean, z_log_var):
batch = tf.shape(z_mean)[0]
dim = tf.shape(z_mean)[1]
epsilon = tf.random.normal(shape=(batch, dim))
stddev = tf.exp(0.5 * z_log_var)
return z_mean + stddev * epsilon
And finally this is the Decoder:
@keras.saving.register_keras_serializable()
class Decoder(keras.layers.Layer):
def __init__(self):
super(Decoder, self).__init__()
self.dense1 = Dense(units=4096, activation="relu")
self.bn1 = BatchNormalization()
self.dense2 = Dense(units=1024, activation="relu")
self.bn2 = BatchNormalization()
self.dense3 = Dense(units=4096, activation="relu")
self.bn3 = BatchNormalization()
seed = 42
self.reshape = Reshape((4, 4, 256))
self.deconv1 = Conv2DTranspose(filters=256, kernel_size=3, activation="relu", strides=2, padding="same",
kernel_initializer=he_uniform(seed))
self.bn4 = BatchNormalization()
self.deconv2 = Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=1, padding="same",
kernel_initializer=he_uniform(seed))
self.bn5 = BatchNormalization()
self.deconv3 = Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=2, padding="valid",
kernel_initializer=he_uniform(seed))
self.bn6 = BatchNormalization()
self.deconv4 = Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=1, padding="valid",
kernel_initializer=he_uniform(seed))
self.bn7 = BatchNormalization()
self.deconv5 = Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=2, padding="valid",
kernel_initializer=he_uniform(seed))
self.bn8 = BatchNormalization()
self.deconv6 = Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid",
kernel_initializer=he_uniform(seed))
def call(self, inputs, training=None, mask=None):
x = self.dense1(inputs)
x = self.bn1(x)
x = self.dense2(x)
x = self.bn2(x)
x = self.dense3(x)
x = self.bn3(x)
x = self.reshape(x)
x = self.deconv1(x)
x = self.bn4(x)
x = self.deconv2(x)
x = self.bn5(x)
x = self.deconv3(x)
x = self.bn6(x)
x = self.deconv4(x)
x = self.bn7(x)
x = self.deconv5(x)
x = self.bn8(x)
decoder_outputs = self.deconv6(x)
return decoder_outputs
This is the main
code:
def normalize(x):
return (x - np.min(x)) / (np.max(x) - np.min(x))
def create_vae():
latent_dimension = 25
best_epochs = 2500
best_l_rate = 10 ** -5
best_batch_size = 32
best_patience = 30
encoder = Encoder(latent_dimension)
decoder = Decoder()
vae = VAE(encoder, decoder, best_epochs, best_l_rate, best_batch_size, best_patience)
vae.compile(Adam(best_l_rate))
return vae
if __name__ == '__main__':
(x_train, y_train), (x_test, y_test) = mnist.load_data()
new_shape = (40, 40) # VAE deals with (None, 40, 40, 1) tensors
x_train = np.array([resize(img, new_shape) for img in x_train])
x_test = np.array([resize(img, new_shape) for img in x_test])
x_train = np.expand_dims(x_train, axis=-1).astype("float32")
x_test = np.expand_dims(x_test, axis=-1).astype("float32")
x_train = normalize(x_train)
x_test = normalize(x_test)
# Let's consider the first 100 items only for speed purposes
x_train = x_train[:100]
y_train = y_train[:100]
x_test = x_test[:100]
y_test = y_test[:100]
model = create_vae()
model.fit(x_train, y_train, batch_size=64, epochs=10)
weights_before_load = model.get_weights()
model.save_weights("test-checkpoints/my-vae")
del model
model = create_vae()
model.load_weights("test-checkpoints/my-vae")
weights_after_load = model.get_weights()
for layer_num, (w_before, w_after) in enumerate(zip(weights_before_load, weights_after_load), start=1):
print(f"Layer {layer_num}:")
print(f"Same weights? {w_before.all() == w_after.all()}")
And this is the output:
Layer 1:
Same weights? True
Layer 2:
Same weights? False # WHY FALSE HERE?
Layer 3:
Same weights? True
Layer 4:
Same weights? True
Layer 5:
Same weights? True
Layer 6:
Same weights? True
But I expected the weights to be the same before and after the load! Why aren't the weights in the Layer 2 the same after I loaded them using load_weights
method?
Furthermore this is a list of warnings I get:
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn1.gamma
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.dense.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.dense.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.z_mean.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.z_mean.bias
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn2.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.dense3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.dense3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv4.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv4.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv5.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv5.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv6.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv6.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn1.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn1.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn2.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn2.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.dense.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.dense.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_mean.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_mean.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_log_var.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_log_var.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn1.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn1.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn2.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn2.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn4.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn4.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv2.bias
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).decoder.deconv6.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).decoder.deconv6.bias
Process finished with exit code 0
How do I fix the issue?
Note that I truncated the warnings output because of their size.
This is weird since if I re-run my example on another model as you can see here:
def create_model():
model = tf.keras.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer=Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return model
if __name__ == '__main__':
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
y_train = y_train[:100]
y_test = y_test[:100]
x_train = x_train[:100].reshape(-1, 28 * 28) / 255.0
x_test = x_test[:100].reshape(-1, 28 * 28) / 255.0
model = create_model()
model.fit(x_train, y_train, batch_size=32, epochs=10)
model.save_weights("test-checkpoints/my-model")
weights_before = model.get_weights()
del model
model = create_model()
model.load_weights("test-checkpoints/my-model")
weights_after = model.get_weights()
for layer_num, (w_before, w_after) in enumerate(zip(weights_before, weights_after), start=1):
print(f"Layer {layer_num}:")
print(f"Same weights? {w_before.all() == w_after.all()}")
then you can notice that model.save_weights()
and model.load_weights()
worked properly, since weights are all the same:
Layer 1:
Same weights? True
Layer 2:
Same weights? True
Layer 3:
Same weights? True
Layer 4:
Same weights? True
Before loading weights I had to train the model on a batch like this:
vae.train_on_batch(x_train[:1], x_train[:1])