
A loaded Keras model with a custom layer has different weights from the model that was saved


I have implemented a Transformer encoder in Keras using the template provided by Francois Chollet here. After training the model I save it with model.save, but when I load it again for inference the weights appear to have been re-initialized, so the model loses all of its inference ability.

I have looked at similar issues on Stack Overflow and GitHub, and applied the following suggestions, but I am still getting the same issue:

  1. Use the @tf.keras.utils.register_keras_serializable() decorator on the class.
  2. Make sure **kwargs is accepted in __init__ and passed on to super().__init__().
  3. Make sure the custom layer has get_config and from_config methods.
  4. Use custom_object_scope when loading the model.

Below is a minimal reproducible example that replicates the issue. How do I change it so that the model weights are saved and restored correctly?

import numpy as np
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import custom_object_scope

@tf.keras.utils.register_keras_serializable()
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        if mask is not None:
            mask = mask[:, tf.newaxis, :]
        attention_output = self.attention(
            inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Create simple model:
encoder = TransformerEncoder(embed_dim=2, dense_dim=2, num_heads=1)
inputs = keras.Input(shape=(2, 2), batch_size=None, name="test_inputs")
x = encoder(inputs)
x = layers.Flatten()(x)
outputs = layers.Dense(1, activation="linear")(x)
model = keras.Model(inputs, outputs)

# Fit the model and save it:
np.random.seed(42)
X = np.random.rand(10, 2, 2)
y = np.ones(10)
model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
model.fit(X, y, epochs=2, batch_size=1)
model.save("./test_model")

# Load the saved model:
with custom_object_scope({
    'TransformerEncoder': TransformerEncoder
}):
    loaded_model = load_model("./test_model")

print(model.weights[0].numpy())
print(loaded_model.weights[0].numpy())
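
For reference, here is a quick programmatic check of the mismatch across all weight tensors rather than eyeballing the first one (a sketch appended to the script above; it only uses objects the script already defines):

# List every weight tensor whose values changed after reloading.
mismatched = [
    w_old.name
    for w_old, w_new in zip(model.weights, loaded_model.weights)
    if not np.allclose(w_old.numpy(), w_new.numpy())
]
print("Weight tensors that differ after reloading:", mismatched)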

Solution

  • The weights are in fact saved; you can restore them with load_weights after loading the model (a sketch of that workaround follows the output below). The problem is that you create new sub-layers in __init__, so the reloaded model gets freshly initialized ones. You need to recreate the sub-layers from their saved configs instead, for example:

    class TransformerEncoder(layers.Layer):
        def __init__(self, embed_dim, dense_dim, num_heads, attention_config=None, dense_proj_config=None, **kwargs):
            super().__init__(**kwargs)
            self.embed_dim = embed_dim
            self.dense_dim = dense_dim
            self.num_heads = num_heads
            # Rebuild the sub-layers from their saved configs when provided;
            # only build fresh (randomly initialized) layers on first construction.
            self.attention = (
                layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
                if attention_config is None
                else layers.MultiHeadAttention.from_config(attention_config)
            )
            self.dense_proj = (
                keras.Sequential(
                    [
                        layers.Dense(dense_dim, activation="relu"),
                        layers.Dense(embed_dim),
                    ]
                )
                if dense_proj_config is None
                else keras.Sequential.from_config(dense_proj_config)
            )
            ...
    
        def call(self, inputs, mask=None):
            ...
    
        def get_config(self):
            config = super().get_config()
            config.update({
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "dense_dim": self.dense_dim,
                "attention_config": self.attention.get_config(),
                "dense_proj_config": self.dense_proj.get_config(),
            })
            return config
    

    Output:

    [[[-0.810745   -0.14727005]]

     [[ 0.8542909   0.09689581]]]
    [[[-0.810745   -0.14727005]]

     [[ 0.8542909   0.09689581]]]
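
  • As noted above, the trained values are on disk even without this fix, so an alternative workaround is to let load_model rebuild the (randomly initialized) architecture and then restore the values with load_weights. A minimal sketch, continuing from the script in the question and assuming a recent TF 2.x release in which load_weights accepts the SavedModel directory written by model.save:

    with custom_object_scope({
        'TransformerEncoder': TransformerEncoder
    }):
        loaded_model = load_model("./test_model")

    # Point load_weights at the SavedModel directory so the stored trained
    # values overwrite the freshly initialized ones (assumes a TF 2.x
    # release where a SavedModel path is accepted here).
    loaded_model.load_weights("./test_model")

    # Sanity check: every weight tensor should now match the trained model.
    for w_old, w_new in zip(model.weights, loaded_model.weights):
        assert np.allclose(w_old.numpy(), w_new.numpy())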