Search code examples
pythontensorflowmachine-learningkeras

model.load_weights() does not load the weights I stored previously using model.save_weights()


I aim to save and then load the weights of my model using save_weights and load_weights functions.

In order to show you a minimal reproducible example, these are the dependencies you can use in my whole example:

import numpy as np
import tensorflow as tf
from keras.initializers import he_uniform
from keras.layers import Conv2DTranspose, BatchNormalization, Reshape, Dense, Conv2D, Flatten
from keras.optimizers.legacy import Adam
from keras.src.datasets import mnist
from skimage.transform import resize
from sklearn.base import BaseEstimator
from tensorflow import keras

This is my model, a (variational) autoencoder:

class VAE(keras.Model, BaseEstimator):
    def __init__(self, encoder, decoder, epochs=None, l_rate=None, batch_size=None, patience=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.epochs = epochs  
        self.l_rate = l_rate  
        self.batch_size = batch_size  
        self.patience = patience 
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def call(self, inputs, training=None, mask=None):
        _, _, z = self.encoder(inputs)
        outputs = self.decoder(z)
        return outputs

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        data, labels = data
        with tf.GradientTape() as tape:
            # Forward pass
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)

            # Compute losses
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss

        # Compute gradient
        grads = tape.gradient(total_loss, self.trainable_weights)

        # Update weights
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # Update metrics
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def test_step(self, data):
        data, labels = data
        # Forward pass
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)

        # Compute losses
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss

        # Update metrics
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

This is the Encoder:

@keras.saving.register_keras_serializable()
class Encoder(keras.layers.Layer):
    def __init__(self, latent_dimension):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dimension

        seed = 42

        self.conv1 = Conv2D(filters=64, kernel_size=3, activation="relu", strides=2, padding="same",
                            kernel_initializer=he_uniform(seed))
        self.bn1 = BatchNormalization()

        self.conv2 = Conv2D(filters=128, kernel_size=3, activation="relu", strides=2, padding="same",
                            kernel_initializer=he_uniform(seed))
        self.bn2 = BatchNormalization()

        self.conv3 = Conv2D(filters=256, kernel_size=3, activation="relu", strides=2, padding="same",
                            kernel_initializer=he_uniform(seed))
        self.bn3 = BatchNormalization()

        self.flatten = Flatten()
        self.dense = Dense(units=100, activation="relu")

        self.z_mean = Dense(latent_dimension, name="z_mean")
        self.z_log_var = Dense(latent_dimension, name="z_log_var")

        self.sampling = sample

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.flatten(x)
        x = self.dense(x)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        z = self.sampling(z_mean, z_log_var)
        return z_mean, z_log_var, z

Where sample function is defined below:

def sample(z_mean, z_log_var):
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    epsilon = tf.random.normal(shape=(batch, dim))
    stddev = tf.exp(0.5 * z_log_var)
    return z_mean + stddev * epsilon

And finally this is the Decoder:

@keras.saving.register_keras_serializable()
class Decoder(keras.layers.Layer):
    def __init__(self):
        super(Decoder, self).__init__()
        self.dense1 = Dense(units=4096, activation="relu")
        self.bn1 = BatchNormalization()

        self.dense2 = Dense(units=1024, activation="relu")
        self.bn2 = BatchNormalization()

        self.dense3 = Dense(units=4096, activation="relu")
        self.bn3 = BatchNormalization()

        seed = 42

        self.reshape = Reshape((4, 4, 256))
        self.deconv1 = Conv2DTranspose(filters=256, kernel_size=3, activation="relu", strides=2, padding="same",
                                       kernel_initializer=he_uniform(seed))
        self.bn4 = BatchNormalization()

        self.deconv2 = Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=1, padding="same",
                                       kernel_initializer=he_uniform(seed))
        self.bn5 = BatchNormalization()

        self.deconv3 = Conv2DTranspose(filters=128, kernel_size=3, activation="relu", strides=2, padding="valid",
                                       kernel_initializer=he_uniform(seed))
        self.bn6 = BatchNormalization()

        self.deconv4 = Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=1, padding="valid",
                                       kernel_initializer=he_uniform(seed))
        self.bn7 = BatchNormalization()

        self.deconv5 = Conv2DTranspose(filters=64, kernel_size=3, activation="relu", strides=2, padding="valid",
                                       kernel_initializer=he_uniform(seed))
        self.bn8 = BatchNormalization()

        self.deconv6 = Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid",
                                       kernel_initializer=he_uniform(seed))

    def call(self, inputs, training=None, mask=None):
        x = self.dense1(inputs)
        x = self.bn1(x)
        x = self.dense2(x)
        x = self.bn2(x)
        x = self.dense3(x)
        x = self.bn3(x)
        x = self.reshape(x)
        x = self.deconv1(x)
        x = self.bn4(x)
        x = self.deconv2(x)
        x = self.bn5(x)
        x = self.deconv3(x)
        x = self.bn6(x)
        x = self.deconv4(x)
        x = self.bn7(x)
        x = self.deconv5(x)
        x = self.bn8(x)
        decoder_outputs = self.deconv6(x)
        return decoder_outputs

This is the main code:

def normalize(x):
    return (x - np.min(x)) / (np.max(x) - np.min(x))
    
def create_vae():
    latent_dimension = 25
    best_epochs = 2500
    best_l_rate = 10 ** -5
    best_batch_size = 32
    best_patience = 30

    encoder = Encoder(latent_dimension)
    decoder = Decoder()
    vae = VAE(encoder, decoder, best_epochs, best_l_rate, best_batch_size, best_patience)
    vae.compile(Adam(best_l_rate))
    return vae

if __name__ == '__main__':
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    new_shape = (40, 40)  # VAE deals with (None, 40, 40, 1) tensors
    x_train = np.array([resize(img, new_shape) for img in x_train])
    x_test = np.array([resize(img, new_shape) for img in x_test])
    x_train = np.expand_dims(x_train, axis=-1).astype("float32")
    x_test = np.expand_dims(x_test, axis=-1).astype("float32")

    x_train = normalize(x_train)
    x_test = normalize(x_test)

    # Let's consider the first 100 items only for speed purposes
    x_train = x_train[:100]
    y_train = y_train[:100]
    x_test = x_test[:100]
    y_test = y_test[:100]
    
    model = create_vae()
    model.fit(x_train, y_train, batch_size=64, epochs=10)
    weights_before_load = model.get_weights()
    model.save_weights("test-checkpoints/my-vae")

    del model

    model = create_vae()
    model.load_weights("test-checkpoints/my-vae")
    weights_after_load = model.get_weights()

    for layer_num, (w_before, w_after) in enumerate(zip(weights_before_load, weights_after_load), start=1):
        print(f"Layer {layer_num}:")
        print(f"Same weights? {w_before.all() == w_after.all()}")

And this is the output:

Layer 1:
Same weights? True

Layer 2:
Same weights? False  # WHY FALSE HERE?

Layer 3:
Same weights? True

Layer 4:
Same weights? True

Layer 5:
Same weights? True

Layer 6:
Same weights? True

But I expected the weights to be the same before and after the load! Why aren't the weights in the Layer 2 the same after I loaded them using load_weights method?

Furthermore this is a list of warnings I get:

WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn1.gamma
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.conv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.bn3.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.dense.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.dense.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.z_mean.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).encoder.z_mean.bias
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn2.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.dense3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.dense3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn3.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn4.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn5.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn6.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv4.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv4.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn7.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv5.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv5.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.moving_mean
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.bn8.moving_variance
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv6.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).decoder.deconv6.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn1.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn1.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn2.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn2.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.conv3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.dense.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.dense.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_mean.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_mean.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_log_var.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).encoder.z_log_var.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn1.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn1.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense2.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn2.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn2.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense3.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.dense3.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn3.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn3.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv1.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv1.bias
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn4.gamma
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.bn4.beta
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv2.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).decoder.deconv2.bias
...
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).decoder.deconv6.kernel
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer's state 'v' for (root).decoder.deconv6.bias

Process finished with exit code 0

How do I fix the issue?

Note that I truncated the warnings output because of their size.

This is weird since if I re-run my example on another model as you can see here:

def create_model():
    model = tf.keras.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10)
    ])

    model.compile(optimizer=Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    return model

if __name__ == '__main__':
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    y_train = y_train[:100]
    y_test = y_test[:100]
    x_train = x_train[:100].reshape(-1, 28 * 28) / 255.0
    x_test = x_test[:100].reshape(-1, 28 * 28) / 255.0

    model = create_model()
    model.fit(x_train, y_train, batch_size=32, epochs=10)

    model.save_weights("test-checkpoints/my-model")
    weights_before = model.get_weights()

    del model

    model = create_model()
    model.load_weights("test-checkpoints/my-model")
    weights_after = model.get_weights()

    for layer_num, (w_before, w_after) in enumerate(zip(weights_before, weights_after), start=1):
        print(f"Layer {layer_num}:")
        print(f"Same weights? {w_before.all() == w_after.all()}")

then you can notice that model.save_weights() and model.load_weights() worked properly, since weights are all the same:

Layer 1:
Same weights? True


Layer 2:
Same weights? True


Layer 3:
Same weights? True


Layer 4:
Same weights? True

Solution

  • Before loading weights I had to train the model on a batch like this:

    vae.train_on_batch(x_train[:1], x_train[:1])