Search code examples
python tensorflow keras image-classification vision-transformer

Not able to load a Vision Transformer model post training


So I am using the example of the vision transformer model for image classification provided on the Keras website. The only difference is I have added a line to save the model once it is done training as a ".keras" file.

Later I try to load the saved model and check its configuration using "get_config()".

Lmodel=load_model("VITexp.keras")
Lmodel.get_config()

But the code fails to load the model and gives me the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/ops/operation.py:208, in Operation.from_config(cls, config)
    207 try:
--> 208     return cls(**config)
    209 except Exception as e:

TypeError: Patches.__init__() got an unexpected keyword argument 'name'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:718, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    717 try:
--> 718     instance = cls.from_config(inner_config)
    719 except TypeError as e:

File /opt/conda/lib/python3.10/site-packages/keras/src/ops/operation.py:210, in Operation.from_config(cls, config)
    209 except Exception as e:
--> 210     raise TypeError(
    211         f"Error when deserializing class '{cls.__name__}' using "
    212         f"config={config}.\n\nException encountered: {e}"
    213     )

TypeError: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.

Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:718, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    717 try:
--> 718     instance = cls.from_config(inner_config)
    719 except TypeError as e:

File /opt/conda/lib/python3.10/site-packages/keras/src/models/model.py:517, in Model.from_config(cls, config, custom_objects)
    515     from keras.src.models.functional import functional_from_config
--> 517     return functional_from_config(
    518         cls, config, custom_objects=custom_objects
    519     )
    521 # Either the model has a custom __init__, or the config
    522 # does not contain all the information necessary to
    523 # revive a Functional model. This happens when the user creates
   (...)
    526 # In this case, we fall back to provide all config into the
    527 # constructor of the class.

File /opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py:517, in functional_from_config(cls, config, custom_objects)
    516 for layer_data in config["layers"]:
--> 517     process_layer(layer_data)
    519 # Then we process nodes in order of layer depth.
    520 # Nodes that cannot yet be processed (if the inbound node
    521 # does not yet exist) are re-enqueued, and the process
    522 # is repeated until all nodes are processed.

File /opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py:501, in functional_from_config.<locals>.process_layer(layer_data)
    500 else:
--> 501     layer = serialization_lib.deserialize_keras_object(
    502         layer_data, custom_objects=custom_objects
    503     )
    504 created_layers[layer_name] = layer

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:720, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    719 except TypeError as e:
--> 720     raise TypeError(
    721         f"{cls} could not be deserialized properly. Please"
    722         " ensure that components that are Python object"
    723         " instances (layers, models, etc.) returned by"
    724         " `get_config()` are explicitly deserialized in the"
    725         " model's `from_config()` method."
    726         f"\n\nconfig={config}.\n\nException encountered: {e}"
    727     )
    728 build_config = config.get("build_config", None)

TypeError: <class '__main__.Patches'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

config={'module': None, 'class_name': 'Patches', 'config': {'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}, 'registered_name': 'Custom>Patches', 'build_config': {'input_shape': [None, 72, 72, 3]}, 'name': 'patches_1', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 72, 72, 3], 'dtype': 'float32', 'keras_history': ['data_augmentation', 0, 0]}}], 'kwargs': {}}]}.

Exception encountered: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.

Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[11], line 1
----> 1 Lmodel=load_model("VITexp.keras")
      2 Lmodel.get_config()

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_api.py:176, in load_model(filepath, custom_objects, compile, safe_mode)
    173         is_keras_zip = True
    175 if is_keras_zip:
--> 176     return saving_lib.load_model(
    177         filepath,
    178         custom_objects=custom_objects,
    179         compile=compile,
    180         safe_mode=safe_mode,
    181     )
    182 if str(filepath).endswith((".h5", ".hdf5")):
    183     return legacy_h5_format.load_model_from_hdf5(
    184         filepath, custom_objects=custom_objects, compile=compile
    185     )

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:152, in load_model(filepath, custom_objects, compile, safe_mode)
    147     raise ValueError(
    148         "Invalid filename: expected a `.keras` extension. "
    149         f"Received: filepath={filepath}"
    150     )
    151 with open(filepath, "rb") as f:
--> 152     return _load_model_from_fileobj(
    153         f, custom_objects, compile, safe_mode
    154     )

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:170, in _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode)
    168 # Construct the model from the configuration file in the archive.
    169 with ObjectSharingScope():
--> 170     model = deserialize_keras_object(
    171         config_dict, custom_objects, safe_mode=safe_mode
    172     )
    174 all_filenames = zf.namelist()
    175 if _VARS_FNAME + ".h5" in all_filenames:

File /opt/conda/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:720, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    718     instance = cls.from_config(inner_config)
    719 except TypeError as e:
--> 720     raise TypeError(
    721         f"{cls} could not be deserialized properly. Please"
    722         " ensure that components that are Python object"
    723         " instances (layers, models, etc.) returned by"
    724         " `get_config()` are explicitly deserialized in the"
    725         " model's `from_config()` method."
    726         f"\n\nconfig={config}.\n\nException encountered: {e}"
    727     )
    728 build_config = config.get("build_config", None)
    729 if build_config and not instance.built:

TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

Exception encountered: <class '__main__.Patches'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by `get_config()` are explicitly deserialized in the model's `from_config()` method.

config={'module': None, 'class_name': 'Patches', 'config': {'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}, 'registered_name': 'Custom>Patches', 'build_config': {'input_shape': [None, 72, 72, 3]}, 'name': 'patches_1', 'inbound_nodes': [{'args': [{'class_name': '__keras_tensor__', 'config': {'shape': [None, 72, 72, 3], 'dtype': 'float32', 'keras_history': ['data_augmentation', 0, 0]}}], 'kwargs': {}}]}.

Exception encountered: Error when deserializing class 'Patches' using config={'name': 'patches_1', 'trainable': True, 'dtype': 'float32', 'patch_size': 6}.

Exception encountered: Patches.__init__() got an unexpected keyword argument 'name'

The code is copy pasted from the website except for the save and load model commands.

Please help me solve this. Is there a specific way to save these models to be accessed later in completely different notebooks? (I am using Kaggle for this code)


Solution

  • To sum up my comments under the question and give a comprehensive answer, the full code is below. I used the code from the link in the question; the lines I added are marked with # comments. Only the layer classes have to be modified.

    @keras.saving.register_keras_serializable()  # <- this line
    class Patches(layers.Layer):
        """Layer that cuts a batch of images into flattened square patches.

        Args:
            patch_size: Side length handed to `extract_patches` as `size`.
            **kwargs: Forwarded unchanged to the base layer constructor
                (for example `name`, which is part of the saved config).
        """

        def __init__(self, patch_size, **kwargs):  # <- add **kwargs
            super().__init__(**kwargs)  # <- add **kwargs
            self.patch_size = patch_size

        def call(self, images):
            """Extract patches and flatten each one into a single vector.

            The input is indexed as (batch, height, width, channels); the
            result is reshaped to
            (batch, patches_per_image, patch_size * patch_size * channels).
            """
            shape = ops.shape(images)
            side = self.patch_size
            rows = shape[1] // side
            cols = shape[2] // side
            flat_dim = side * side * shape[3]
            tiles = keras.ops.image.extract_patches(images, size=side)
            return ops.reshape(tiles, (shape[0], rows * cols, flat_dim))

        def get_config(self):
            """Return the base layer config extended with `patch_size`."""
            cfg = super().get_config()
            cfg["patch_size"] = self.patch_size
            return cfg
    
    # ------------------------------------------------------------------
    
    
    @keras.saving.register_keras_serializable()  # this line
    class PatchEncoder(layers.Layer):
        """Linearly project flattened patches and add a position embedding.

        Args:
            num_patches: Number of patches per image; also the number of
                rows in the position-embedding table.
            projection_dim: Output width of both the dense projection and
                the position embedding.
            **kwargs: Forwarded unchanged to the base layer constructor
                (for example `name`, which is part of the saved config).
        """

        def __init__(self, num_patches, projection_dim, **kwargs):  # <- add **kwargs
            super().__init__(**kwargs)  # <- add **kwargs
            self.num_patches = num_patches
            self.projection_dim = projection_dim  # save projection_dim
            self.projection = layers.Dense(units=projection_dim)
            self.position_embedding = layers.Embedding(
                input_dim=num_patches, output_dim=projection_dim
            )

        def build(self, input_shape):  # add build method (avoids the unbuilt-state warning)
            # Build the sub-layers explicitly so their weights exist as soon
            # as this layer is marked as built (e.g. when Keras rebuilds it
            # from the saved `build_config`), instead of only silencing the
            # warning with an empty override.
            self.projection.build(input_shape)
            # The embedding is looked up with a (1, num_patches) index tensor
            # in `call`.
            self.position_embedding.build((1, self.num_patches))
            super().build(input_shape)

        def call(self, patch):
            """Return the projected patches plus their position embeddings."""
            positions = ops.expand_dims(
                ops.arange(start=0, stop=self.num_patches, step=1), axis=0
            )
            projected_patches = self.projection(patch)
            encoded = projected_patches + self.position_embedding(positions)
            return encoded

        def get_config(self):
            """Return the base config plus both constructor arguments."""
            config = super().get_config()
            config.update(
                {
                    "num_patches": self.num_patches,
                    "projection_dim": self.projection_dim,  # this line
                }
            )
            return config
    

    Short explanation of the added code lines:

    @keras.saving.register_keras_serializable()
    

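    This decorator registers the custom layer class with Keras's serialization registry, so that load_model() can look the class up by its registered name (here "Custom>Patches") when it rebuilds the model.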

    **kwargs
    

    This accepts any keyword arguments that __init__() does not declare itself and forwards them to the super().__init__() call. In this case, __init__() was passed the parameter name, because every Layer has one and it is stored in the layer's config. The original __init__() did not accept name, which is what raised the TypeError.

    self.projection_dim = projection_dim
    # ...
    config.update({"projection_dim": self.projection_dim})
    

    These two lines save projection_dim to the config of the PatchEncoder layer, so that the same value can be passed back to __init__() when the layer is loaded again.