I have implemented a Transformer encoder in Keras using the template provided by François Chollet here. After I train the model, I save it using `model.save`, but when I load it again for inference I find that the weights seem to be random again, and therefore my model loses all inference ability.
I have looked at similar issues on Stack Overflow and GitHub, and applied the following suggestions, but I am still getting the same issue:
- Using the `@tf.keras.utils.register_keras_serializable()` decorator on the class.
- Making sure `**kwargs` is in the `__init__` call.
- Implementing the `get_config` and `from_config` methods.
- Using `custom_object_scope` to load the model.

Below is a minimal reproducible example that replicates the issue. How do I change it so that the model weights are saved correctly?
import numpy as np
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import layers
from keras.models import load_model
from keras.utils import custom_object_scope
@tf.keras.utils.register_keras_serializable()
class TransformerEncoder(layers.Layer):
    """Transformer encoder block (the repro version from the question).

    Applies multi-head self-attention to the input. It then adds a residual
    connection followed by layer normalization, a two-layer feed-forward
    projection, and a second residual connection with layer normalization.

    NOTE(review): every sub-layer is constructed fresh in ``__init__``.
    ``get_config`` records only ``embed_dim``, ``num_heads`` and
    ``dense_dim``, so ``from_config`` rebuilds the sub-layers from scratch
    rather than from their saved configuration. This is the setup in which
    the reloaded weights were reported to look freshly initialized. One fix
    is to serialize the sub-layer configs in ``get_config`` and rebuild the
    sub-layers from them in ``__init__``.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        # Extra keyword arguments are forwarded to the base Layer. This lets
        # from_config() pass back the base-class keys that get_config()
        # includes via super().get_config().
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # Self-attention; embed_dim is reused as the per-head key dimension.
        # NOTE(review): presumably these attribute names key the sub-layers'
        # weights in the saved checkpoint -- confirm before renaming them.
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        # Feed-forward projection: expand to dense_dim (ReLU), then project
        # back to embed_dim so it can be added to its input.
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Run attention and projection with residual connections.

        Args:
            inputs: Input tensor; assumed (batch, seq_len, embed_dim) --
                TODO confirm.
            mask: Optional mask; assumed (batch, seq_len) -- TODO confirm.
                A new axis is inserted at position 1 before it is passed to
                the attention layer as ``attention_mask``.

        Returns:
            Output of the second layer normalization, applied to
            ``proj_input + proj_output``.
        """
        if mask is not None:
            mask = mask[:, tf.newaxis, :]
        # Query and value are both `inputs`, i.e. self-attention.
        attention_output = self.attention(
            inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        """Return the base Layer config plus the constructor arguments.

        Only ``embed_dim``, ``num_heads`` and ``dense_dim`` are added; the
        sub-layers' own configs are not recorded.
        """
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate the layer by passing the config dict to the constructor."""
        return cls(**config)
# Create simple model:
# A single encoder block, flattened and reduced to one linear output.
encoder = TransformerEncoder(embed_dim=2, dense_dim=2, num_heads=1)
inputs = keras.Input(shape=(2, 2), batch_size=None, name="test_inputs")
x = encoder(inputs)
x = layers.Flatten()(x)
outputs = layers.Dense(1, activation="linear")(x)
model = keras.Model(inputs, outputs)
# Fit the model and save it:
# NOTE(review): this seeds NumPy only, which makes X below reproducible;
# Keras weight initialization presumably uses TensorFlow's own RNG --
# confirm if reproducible initial weights are needed.
np.random.seed(42)
X = np.random.rand(10, 2, 2)
y = np.ones(10)
model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
# Train briefly so the weights move away from their initial values.
model.fit(X, y, epochs=2, batch_size=1)
# NOTE(review): a path with no file extension presumably selects the
# TensorFlow SavedModel directory format (TF 2.x) -- confirm for the
# Keras version in use.
model.save("./test_model")
# Load the saved model:
# The scope maps the class name to the custom class during deserialization.
with custom_object_scope({
    'TransformerEncoder': TransformerEncoder
}):
    loaded_model = load_model("./test_model")
# Compare the first weight variable of the trained and the reloaded model;
# matching values mean the trained weights survived the save/load round trip.
print(model.weights[0].numpy())
print(loaded_model.weights[0].numpy())
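The weights are saved (you can load them with `load_weights` after loading the model). The problem is that `__init__` always creates new sub-layers, and `get_config` does not save their configuration. You need to serialize the sub-layer configs in `get_config` and recreate the sub-layers from them in `__init__`, for example: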
class TransformerEncoder(layers.Layer):
    """Transformer encoder block whose sub-layers are restored from config.

    Applies multi-head self-attention to the input. It then adds a residual
    connection followed by layer normalization, a two-layer feed-forward
    projection, and a second residual connection with layer normalization.

    ``get_config`` also records the configs of the attention and projection
    sub-layers. On reload, ``from_config`` passes them back to ``__init__``,
    which recreates those sub-layers from the saved configs instead of
    building new ones from scratch.

    Args:
        embed_dim: Feature size of the input/output; also used as the
            attention ``key_dim``.
        dense_dim: Width of the hidden ``Dense`` layer in the projection.
        num_heads: Number of attention heads.
        attention_config: Optional ``MultiHeadAttention`` config, as returned
            by its ``get_config``. When given, the attention layer is rebuilt
            from it, and ``num_heads``/``embed_dim`` are not used to
            construct it.
        dense_proj_config: Optional ``keras.Sequential`` config for the
            projection. When given, the projection is rebuilt from it, and
            ``dense_dim``/``embed_dim`` are not used to construct it.
        **kwargs: Forwarded to the base ``layers.Layer``.
    """

    def __init__(self, embed_dim, dense_dim, num_heads,
                 attention_config=None, dense_proj_config=None, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        if attention_config is None:
            # Fresh construction: self-attention with embed_dim per-head keys.
            self.attention = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim)
        else:
            # Loading path: recreate the sub-layer from the config saved by
            # get_config() instead of building a brand-new one.
            self.attention = layers.MultiHeadAttention.from_config(
                attention_config)

        if dense_proj_config is None:
            # Expand to dense_dim (ReLU), then project back to embed_dim so
            # the result can be added to its input.
            self.dense_proj = keras.Sequential(
                [
                    layers.Dense(dense_dim, activation="relu"),
                    layers.Dense(embed_dim),
                ]
            )
        else:
            self.dense_proj = keras.Sequential.from_config(dense_proj_config)

        # NOTE(review): the layer-norm sub-layers are not given saved configs;
        # confirm their weights also round-trip, not just weights[0].
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Run attention and projection with residual connections.

        Args:
            inputs: Input tensor; assumed (batch, seq_len, embed_dim) --
                TODO confirm.
            mask: Optional mask; assumed (batch, seq_len) -- TODO confirm.
                A new axis is inserted at position 1 before it is passed to
                the attention layer as ``attention_mask``.

        Returns:
            Output of the second layer normalization, applied to
            ``proj_input + proj_output``.
        """
        if mask is not None:
            mask = mask[:, tf.newaxis, :]
        # Query and value are both `inputs`, i.e. self-attention.
        attention_output = self.attention(
            inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        """Return the base config, constructor args, and sub-layer configs.

        The ``attention_config`` and ``dense_proj_config`` entries let
        ``__init__`` rebuild those sub-layers from their saved configuration.
        """
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
            "attention_config": self.attention.get_config(),
            "dense_proj_config": self.dense_proj.get_config(),
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate the layer by passing the config dict to the constructor."""
        return cls(**config)
Output (first the trained model's `weights[0]`, then the reloaded model's; the values now match):
[[[-0.810745 -0.14727005]]
[[ 0.8542909 0.09689581]]]
[[[-0.810745 -0.14727005]]
[[ 0.8542909 0.09689581]]]