
Save custom field in Keras model


I have a trained Keras Sequential model, which I save using

keras.saving.save_model(model, path, save_format="...")

Before saving, however, I set a custom list[str] attribute on the model like this:

setattr(model, "custom_attr", ["one", "two", "three"])

Finally, when I reload the model object (from another project) with keras.saving.load_model, I would like my custom attribute to be available as model.custom_attr. However, this doesn't work: custom_attr no longer exists after the model is reloaded.

Is there any way to do that?

I looked into it a bit, and it seems you can pass a custom_objects parameter when reloading the model, but that approach seems to be limited to custom layers or custom loss functions defined in a custom model class. My setting is different, as I have a plain Sequential model.


Solution

  • I solved the issue by subclassing the Sequential class, adding a parameter to the constructor, and carrying it through get_config and from_config:

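    import keras


    # Optional: registering the class lets load_model resolve it by name
    # instead of requiring custom_objects at load time.
    @keras.saving.register_keras_serializable(package="Custom")
    class SequentialWithCustomAttr(keras.Sequential):
        def __init__(self, custom_attr=None, layers=None, trainable=True, name=None, **kwargs):
            # Any other config keys arrive in kwargs and are ignored, as before.
            super().__init__(layers=layers, trainable=trainable, name=name)
            # Avoid a mutable default argument: create a fresh list per instance.
            self.custom_attr = custom_attr if custom_attr is not None else []

        def get_config(self):
            config = super().get_config()
            config.update({"custom_attr": list(self.custom_attr)})
            return config

        @classmethod
        def from_config(cls, config, custom_objects=None):
            config = dict(config)  # Do not mutate the caller's dict.
            custom_attr = config.pop("custom_attr", None)

            # Deserialize layers one by one by using their configs.
            layer_confs = config.pop("layers", None) or []
            layers = [
                keras.saving.deserialize_keras_object(conf, custom_objects=custom_objects)
                for conf in layer_confs
            ]

            # Create an instance of class SequentialWithCustomAttr.
            return cls(custom_attr=custom_attr, layers=layers, **config)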
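
The reason this approach works, as far as I understand it, is that the saved model is rebuilt from whatever get_config returns, so an attribute attached with setattr never reaches the saved file. Below is a minimal, Keras-free sketch of that mechanism. TinyModel, TinyModelWithAttr and save_and_load are hypothetical stand-ins invented for illustration; they are not Keras classes and only mimic the get_config / from_config round trip.

```python
import json


class TinyModel:
    """Hypothetical stand-in for a model: only its config is persisted."""

    def __init__(self, name="model"):
        self.name = name

    def get_config(self):
        # Only the keys returned here survive a save/load round trip.
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class TinyModelWithAttr(TinyModel):
    """Routes the custom attribute through the constructor and the config."""

    def __init__(self, custom_attr=None, name="model"):
        super().__init__(name=name)
        self.custom_attr = custom_attr if custom_attr is not None else []

    def get_config(self):
        config = super().get_config()
        config["custom_attr"] = list(self.custom_attr)
        return config


def save_and_load(model):
    """Round-trip through JSON, keeping only what get_config() returns."""
    payload = json.dumps(model.get_config())
    return type(model).from_config(json.loads(payload))


plain = TinyModel()
plain.custom_attr = ["one", "two", "three"]  # set ad hoc, never enters the config
print(hasattr(save_and_load(plain), "custom_attr"))  # False

tracked = TinyModelWithAttr(custom_attr=["one", "two", "three"])
print(save_and_load(tracked).custom_attr)  # ['one', 'two', 'three']
```

The subclass in the answer follows the same pattern: the attribute is a constructor argument, get_config writes it out, and from_config passes it back in.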