
Could not locate class 'SinePositionEncoding'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`


I trained a transformer with SinePositionEncoding and TransformerEncoder:

Instead of installing keras-hub from pip (which fails on Windows due to a NumPy compatibility error), I copied the source code directly, as shown below:

#@title Define SinePositionEncoding and TransformerEncoder

import keras
from keras.api import ops
from absl import logging
from keras.api import layers
from keras.api.models import Model

def clone_initializer(initializer):

    # If we get a string or dict, just return as we cannot and should not clone.
    if not isinstance(initializer, keras.initializers.Initializer):
        return initializer
    config = initializer.get_config()
    return initializer.__class__.from_config(config)

def _check_masks_shapes(inputs, padding_mask, attention_mask):
    mask = padding_mask
    if hasattr(inputs, "_keras_mask") and mask is None:
        mask = inputs._keras_mask
    if mask is not None:
        if len(mask.shape) != 2:
            raise ValueError(
                "`padding_mask` should have shape "
                "(batch_size, target_length). "
                f"Received shape `{mask.shape}`."
            )
    if attention_mask is not None:
        if len(attention_mask.shape) != 3:
            raise ValueError(
                "`attention_mask` should have shape "
                "(batch_size, target_length, source_length). "
                f"Received shape `{mask.shape}`."
            )

def merge_padding_and_attention_mask(
    inputs,
    padding_mask,
    attention_mask,
):

    _check_masks_shapes(inputs, padding_mask, attention_mask)
    mask = padding_mask
    if hasattr(inputs, "_keras_mask"):
        if mask is None:
            # If no padding mask is explicitly provided, we look for padding
            # mask from the input data.
            mask = inputs._keras_mask
        else:
            logging.warning(
                "You are explicitly setting `padding_mask` while the `inputs` "
                "have built-in mask, so the built-in mask is ignored."
            )
    if mask is not None:
        # Add an axis for broadcasting, the attention mask should be 2D
        # (not including the batch axis).
        mask = ops.cast(ops.expand_dims(mask, axis=1), "int32")
    if attention_mask is not None:
        attention_mask = ops.cast(attention_mask, "int32")
        if mask is None:
            return attention_mask
        else:
            return ops.minimum(mask, attention_mask)
    return mask

class SinePositionEncoding(keras.layers.Layer):

    def __init__(
        self,
        max_wavelength=10000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.built = True

    def call(self, inputs, start_index=0):
        shape = ops.shape(inputs)
        seq_length = shape[-2]
        hidden_size = shape[-1]
        positions = ops.arange(seq_length)
        positions = ops.cast(positions + start_index, self.compute_dtype)
        min_freq = ops.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
        timescales = ops.power(
            min_freq,
            ops.cast(2 * (ops.arange(hidden_size) // 2), self.compute_dtype)
            / ops.cast(hidden_size, self.compute_dtype),
        )
        angles = ops.expand_dims(positions, 1) * ops.expand_dims(timescales, 0)
        # even indices are sine, odd are cosine
        cos_mask = ops.cast(ops.arange(hidden_size) % 2, self.compute_dtype)
        sin_mask = 1 - cos_mask
        # embedding shape is [seq_length, hidden_size]
        positional_encodings = (
            ops.sin(angles) * sin_mask + ops.cos(angles) * cos_mask
        )

        return ops.broadcast_to(positional_encodings, shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_wavelength": self.max_wavelength,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

class TransformerEncoder(keras.layers.Layer):

    def __init__(
        self,
        intermediate_dim,
        num_heads,
        dropout=0,
        activation="relu",
        layer_norm_epsilon=1e-05,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        normalize_first=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.normalize_first = normalize_first
        self.supports_masking = True

    def build(self, inputs_shape):
        # Infer the dimension of our hidden feature size from the build shape.
        hidden_dim = inputs_shape[-1]
        # Attention head size is `hidden_dim` over the number of heads.
        key_dim = int(hidden_dim // self.num_heads)
        if key_dim == 0:
            raise ValueError(
                "Attention `key_dim` computed cannot be zero. "
                f"The `hidden_dim` value of {hidden_dim} has to be equal to "
                f"or greater than `num_heads` value of {self.num_heads}."
            )

        # Self attention layers.
        self._self_attention_layer = keras.layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=key_dim,
            dropout=self.dropout,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="self_attention_layer",
        )
        if hasattr(self._self_attention_layer, "_build_from_signature"):
            self._self_attention_layer._build_from_signature(
                query=inputs_shape,
                value=inputs_shape,
            )
        else:
            self._self_attention_layer.build(
                query_shape=inputs_shape,
                value_shape=inputs_shape,
            )
        self._self_attention_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="self_attention_layer_norm",
        )
        self._self_attention_layer_norm.build(inputs_shape)
        self._self_attention_dropout = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
            name="self_attention_dropout",
        )

        # Feedforward layers.
        self._feedforward_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="feedforward_layer_norm",
        )
        self._feedforward_layer_norm.build(inputs_shape)
        self._feedforward_intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            activation=self.activation,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="feedforward_intermediate_dense",
        )
        self._feedforward_intermediate_dense.build(inputs_shape)
        self._feedforward_output_dense = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="feedforward_output_dense",
        )
        intermediate_shape = list(inputs_shape)
        intermediate_shape[-1] = self.intermediate_dim
        self._feedforward_output_dense.build(tuple(intermediate_shape))
        self._feedforward_dropout = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
            name="feedforward_dropout",
        )
        self.built = True

    def call(
        self,
        inputs,
        padding_mask=None,
        attention_mask=None,
        training=None,
        return_attention_scores=False,
    ):

        x = inputs  # Intermediate result.

        # Compute self attention mask.
        self_attention_mask = merge_padding_and_attention_mask(
            inputs, padding_mask, attention_mask
        )

        # Self attention block.
        residual = x
        if self.normalize_first:
            x = self._self_attention_layer_norm(x)

        if return_attention_scores:
            x, attention_scores = self._self_attention_layer(
                query=x,
                value=x,
                attention_mask=self_attention_mask,
                return_attention_scores=return_attention_scores,
                training=training,
            )
            return x, attention_scores
        else:
            x = self._self_attention_layer(
                query=x,
                value=x,
                attention_mask=self_attention_mask,
                training=training,
            )

        x = self._self_attention_dropout(x, training=training)
        x = x + residual
        if not self.normalize_first:
            x = self._self_attention_layer_norm(x)

        # Feedforward block.
        residual = x
        if self.normalize_first:
            x = self._feedforward_layer_norm(x)
        x = self._feedforward_intermediate_dense(x)
        x = self._feedforward_output_dense(x)
        x = self._feedforward_dropout(x, training=training)
        x = x + residual
        if not self.normalize_first:
            x = self._feedforward_layer_norm(x)

        if return_attention_scores:
            return x, attention_scores

        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "dropout": self.dropout,
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
                "normalize_first": self.normalize_first,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        return inputs_shape
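
To sanity-check the copied layers, they can be run on a dummy batch (a minimal sketch with random data, checking shapes only):

import numpy as np

# Dummy batch of shape (batch, seq_length, hidden_dim) matching the model below.
dummy = ops.convert_to_tensor(np.random.rand(2, 240, 128).astype("float32"))
pos = SinePositionEncoding()(dummy)
out = TransformerEncoder(intermediate_dim=256, num_heads=4)(dummy + pos)
print(pos.shape, out.shape)  # both should be (2, 240, 128)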

Then I trained a model with this architecture:

def get_model():
  encoder_inputs = layers.Input(shape=(240,), name="encoder_inputs", dtype='uint8')

  # embeddings
  token_embeddings = layers.Embedding(input_dim=255, output_dim=128)(encoder_inputs) # Input: Token Size, Output: Embed Dim
  position_encodings = SinePositionEncoding()(token_embeddings)
  embeddings = token_embeddings + position_encodings

  # transformer encoder
  encoder_outputs = TransformerEncoder(intermediate_dim=128*2, num_heads=4, dropout=0.01)(inputs=embeddings)

  # Output layer for vocabulary size of 4
  output_predictions = layers.Dense(units=4, activation=None)(encoder_outputs)

  # Final model
  model = Model(encoder_inputs, output_predictions, name="transformer_encoder")

  return model

I saved the model with model.save('model_best.keras').
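
For reference, the compile and save steps can be reconstructed from the compile_config in the error output below (x_train, y_train, and the epoch count are placeholders):

model = get_model()
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
    jit_compile=True,
)
model.fit(x_train, y_train, epochs=10)  # placeholder data and epoch count
model.save('model_best.keras')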

After re-running the cell with the copied source code, trying to load the model with keras.saving.load_model('model_best.keras') returns this error:

    718     instance = cls.from_config(inner_config)
    719 except TypeError as e:
--> 720     raise TypeError(
    721         f"{cls} could not be deserialized properly. Please"
    722         " ensure that components that are Python object"
    723         " instances (layers, models, etc.) returned by"
    724         " `get_config()` are explicitly deserialized in the"
    725         " model's `from_config()` method."
    726         f"\n\nconfig={config}.\n\nException encountered: {e}"
    727     )
    728 build_config = config.get("build_config", None)
    729 if build_config and not instance.built:

TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

config={'module': 'keras.src.models.functional', 'class_name': 'Functional', 'config': {}, 'registered_name': 'Functional', 'build_config': {'input_shape': None}, 'compile_config': {'optimizer': {'module': 'keras.optimizers', 'class_name': 'Adam', 'config': {'name': 'adam', 'learning_rate': 0.0005000000237487257, 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'loss_scale_factor': None, 'gradient_accumulation_steps': None, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}, 'registered_name': None}, 'loss': {'module': 'keras.losses', 'class_name': 'SparseCategoricalCrossentropy', 'config': {'name': 'sparse_categorical_crossentropy', 'reduction': 'sum_over_batch_size', 'from_logits': True, 'ignore_class': None}, 'registered_name': None}, 'loss_weights': None, 'metrics': ['accuracy'], 'weighted_metrics': None, 'run_eagerly': False, 'steps_per_execution': 1, 'jit_compile': True}}.

Exception encountered: Could not locate class 'SinePositionEncoding'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. Full object config: {'module': None, 'class_name': 'SinePositionEncoding', 'config': {'name': 'sine_position_encoding', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None, 'shared_object_id': 133241101393904}, 'max_wavelength': 10000}, 'registered_name': 'SinePositionEncoding', 'name': 'sine_position_encoding', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 240, 128], 'dtype': 'float32', 'keras_history': ['embedding', 0, 0]}}], 'kwargs': {}}]}

How can I load my model, which took an hour to train?


Solution

  • Unexpectedly, ChatGPT gave me a working solution this time: adding the lines below registers the layers manually.

    # Register the layers manually
    keras.utils.get_custom_objects().update({
        "SinePositionEncoding": SinePositionEncoding,
        "TransformerEncoder": TransformerEncoder,
    })
    

    Then, just load as usual:

    import keras
    
    model = keras.saving.load_model('delineator_base_d128_80hz_best.keras')
    model.summary()
    

    Outputting:

    Model: "transformer_encoder"
    ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
    ┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
    ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
    │ encoder_inputs      │ (None, 240)       │          0 │ -                 │
    │ (InputLayer)        │                   │            │                   │
    ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
    │ embedding           │ (None, 240, 128)  │     32,640 │ encoder_inputs[0… │
    │ (Embedding)         │                   │            │                   │
    ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
    │ sine_position_enco… │ (None, 240, 128)  │          0 │ embedding[0][0]   │
    │ (SinePositionEncod… │                   │            │                   │
    ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
    │ add (Add)           │ (None, 240, 128)  │          0 │ embedding[0][0],  │
    │                     │                   │            │ sine_position_en… │
    ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
    │ transformer_encoder │ (None, 240, 128)  │    132,480 │ add[0][0]         │
    │ (TransformerEncode… │                   │            │                   │
    ├─────────────────────┼───────────────────┼────────────┼───────────────────┤
    │ dense (Dense)       │ (None, 240, 4)    │        516 │ transformer_enco… │
    └─────────────────────┴───────────────────┴────────────┴───────────────────┘
     Total params: 496,910 (1.90 MB)
     Trainable params: 165,636 (647.02 KB)
     Non-trainable params: 0 (0.00 B)
     Optimizer params: 331,274 (1.26 MB)
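
  • Alternatively, the custom classes can be passed directly to load_model via custom_objects, and for future training runs the error message's own suggestion applies: decorate the classes with @keras.saving.register_keras_serializable() before saving, so no manual registration is needed at load time. A sketch (not tested against this exact checkpoint):

    # Option A: pass the custom classes explicitly at load time.
    model = keras.saving.load_model(
        'model_best.keras',
        custom_objects={
            "SinePositionEncoding": SinePositionEncoding,
            "TransformerEncoder": TransformerEncoder,
        },
    )

    # Option B: for future saves, register the classes at definition time,
    # as the error message suggests:
    #
    #     @keras.saving.register_keras_serializable()
    #     class SinePositionEncoding(keras.layers.Layer):
    #         ...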