
Error Saving & Loading Tensorflow/Keras Model With Custom Classes/Functions


I recently created a TensorFlow/Keras model using Keras Transformers. To do this, I wrote custom PositionalEmbedding and TransformerEncoder classes and used them to build the model architecture. They are defined as follows:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


# Adds a learned positional embedding to each frame's feature vector.
class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )
        self.sequence_length = sequence_length
        self.output_dim = output_dim

    def call(self, inputs):
        # The inputs are of shape: `(batch_size, frames, num_features)`
        length = tf.shape(inputs)[1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_positions = self.position_embeddings(positions)
        return inputs + embedded_positions

    def compute_mask(self, inputs, mask=None):
        mask = tf.reduce_any(tf.cast(inputs, "bool"), axis=-1)
        return mask

# A single Transformer encoder block: multi-head self-attention followed by a
# feed-forward projection, each with a residual connection and layer normalization.
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.3
        )
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation=tf.nn.gelu), layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        if mask is not None:
            mask = mask[:, tf.newaxis, :]

        attention_output = self.attention(inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

At first, I was unable to even save this model using the typical model.save() method. However, I was able to get past that by adding a get_config() method to each class, like so:

### FOR THE PositionalEmbedding CLASS
def get_config(self):
    config = super().get_config().copy()
    config.update({
        'position_embeddings': self.position_embeddings,
        'sequence_length': self.sequence_length,
        'output_dim': self.output_dim
    })
    return config

### FOR THE TransformerEncoder CLASS
def get_config(self):
    config = super().get_config().copy()
    config.update({
        'embed_dim': self.embed_dim,
        'dense_dim': self.dense_dim,
        'num_heads': self.num_heads,
        'attention': self.attention,
        'dense_proj': self.dense_proj,
        'layernorm_1': self.layernorm_1,
        'layernorm_2': self.layernorm_2
    })
    return config
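
For comparison, a get_config() that returns only the constructor arguments, leaving the sub-layers to be rebuilt in __init__, would look something like the sketch below. To be clear, this is not the code I was using; the versions above are what were in place for the loading attempts that follow.

### SKETCH ONLY - get_config() limited to constructor arguments
def get_config(self):
    config = super().get_config()
    config.update({
        'sequence_length': self.sequence_length,
        'output_dim': self.output_dim
    })
    return config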

However, when I try to load the model using the keras load_model() method without the custom_objects argument, I get the following error:

ValueError: Unknown layer: PositionalEmbedding. Please ensure this object is passed to the `custom_objects` argument.

And if I use the load_model() method without the class definitions present in the loading script, passing the custom_objects argument for the two classes as load_model('my_model.h5', custom_objects={'PositionalEmbedding': PositionalEmbedding, 'TransformerEncoder': TransformerEncoder}), I get the following error:

NameError: name 'PositionalEmbedding' is not defined

And finally, if I do define the classes (with the updated get_config() methods) before loading and call load_model() as shown in the previous example, I get the following error:

TypeError: ('Keyword argument not understood:', 'position_embeddings')
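
For completeness, the loading call in that last attempt looked roughly like this, with both class definitions (including the get_config() methods above) present in the script:

### ROUGHLY WHAT I RAN FOR THE LAST ATTEMPT
from tensorflow import keras

model = keras.models.load_model(
    'my_model.h5',
    custom_objects={
        'PositionalEmbedding': PositionalEmbedding,
        'TransformerEncoder': TransformerEncoder
    }
)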

Does anyone know what might be causing these errors and how I can resolve them so that I can load this model? Any help is appreciated!

Thanks!

Sam


Solution

  • So I was actually able to solve this problem with a workaround. Instead of saving the whole model and loading it back the usual way, I saved a checkpoint of the model's weights during training, then rebuilt the model from scratch and loaded the checkpoint as its weights.

    The code for that is below:

    ### SAVING THE MODEL WITH CHECKPOINT
    filepath = "/content/drive/MyDrive/tmp/model_checkpoint.ckpt"
    checkpoint = keras.callbacks.ModelCheckpoint(
        filepath, save_weights_only=True, save_best_only=True, verbose=1
    )
    
    history = model.fit(
        train_data,
        train_labels,
        validation_split=0.3,
        epochs=250,
        batch_size=256,
        callbacks=[checkpoint],
    )
    
    ### CREATING NEW MODEL & LOADING CHECKPOINT AS WEIGHTS
    def get_compiled_model():
        sequence_length = MAX_SEQ_LENGTH
        embed_dim = NUM_FEATURES
        dense_dim = 4
        num_heads = 1
        classes = len(label_processor.get_vocabulary())
    
        inputs = keras.Input(shape=(None, None))
        x = PositionalEmbedding(
            sequence_length, embed_dim, name="frame_position_embedding"
        )(inputs)
        x = TransformerEncoder(embed_dim, dense_dim, num_heads, name="transformer_layer")(x)
        x = layers.GlobalMaxPooling1D()(x)
        x = layers.Dropout(0.5)(x)
        outputs = layers.Dense(classes, activation="softmax")(x)
        model = keras.Model(inputs, outputs)
    
        model.compile(
            optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
        )
        return model
    
    model = get_compiled_model()
    
    model.load_weights("/content/drive/MyDrive/tmp/model_checkpoint.ckpt")
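
    As a quick sanity check, the restored model can then be used for inference as usual. This is just a sketch; sample_features here is a hypothetical array of preprocessed frame features with shape (batch_size, MAX_SEQ_LENGTH, NUM_FEATURES), produced the same way as the training data.

    ### VERIFYING THE RESTORED MODEL (sketch; sample_features is hypothetical)
    predictions = model.predict(sample_features)
    predicted_labels = predictions.argmax(axis=-1)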