Search code examples
Tags: python, tensorflow, keras, deep-learning, classification

Can't load a trained Keras model with a custom regularizer class


I'm training the PointNet 3D object classification model on my own dataset, following this Keras tutorial: https://keras.io/examples/vision/pointnet/#point-cloud-classification-with-pointnet

Training works fine, but after training I can't load the saved model. I think the problem is in the part below: the OrthogonalRegularizer class might not be registered properly when I save the model:


@keras.saving.register_keras_serializable('OrthogonalRegularizer')
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Activity regularizer penalizing deviation of reshaped activations from the identity.

    The activations are reshaped to ``(-1, num_features, num_features)``, multiplied
    with themselves via ``tf.tensordot`` and compared against ``tf.eye(num_features)``;
    the penalty is ``l2reg * sum(square(xxt - eye))`` with ``l2reg`` fixed at 0.001.

    NOTE(review): this is the version that fails to load. ``get_config`` emits
    ``l2reg`` and ``eye`` besides ``num_features``; on load they are passed back to
    ``__init__`` (see the config in the traceback), land in ``**kwargs`` and are
    forwarded to ``super().__init__``, which resolves to ``object.__init__`` and
    raises ``TypeError``.
    """

    def __init__(self, num_features, **kwargs):
        # NOTE(review): forwarding **kwargs here is what fails on load -- the traceback
        # shows the call reaching object.__init__(), which takes no extra arguments.
        super(OrthogonalRegularizer, self).__init__(**kwargs)
        self.num_features = num_features
        self.l2reg = 0.001  # hard-coded; the constructor has no parameter for it
        self.eye = tf.eye(num_features)  # identity target the products are compared to

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        # NOTE(review): for B stacked matrices this contraction has shape (B, n, B, n),
        # i.e. it also pairs rows of different samples; per-sample x @ x^T would be
        # tf.matmul(x, x, transpose_b=True) -- confirm which is intended.
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.math.reduce_sum(self.l2reg * tf.square(xxt - self.eye))

    def get_config(self):
        config = {}
        # NOTE(review): "l2reg" and "eye" are not __init__ parameters, and "eye" is a
        # tensor fully determined by num_features; returning them breaks deserialization.
        config.update({"num_features": self.num_features, "l2reg": self.l2reg, "eye": self.eye})
        return config

def tnet(inputs, num_features):
    """Predict a ``num_features x num_features`` transform and apply it to ``inputs``.

    A conv/BN stack (32, 64, 512 channels), global max pooling and a dense/BN stack
    (256, 128 units) feed a Dense layer with ``num_features**2`` outputs. That layer
    starts with a zero kernel and an identity-matrix bias, so the initial predicted
    transform is the identity; its activations are regularized by
    ``OrthogonalRegularizer``. The output is reshaped into a square matrix and
    contracted with ``inputs`` (axis 2 of ``inputs`` against axis 1 of the matrix).
    """
    identity_bias = keras.initializers.Constant(np.eye(num_features).flatten())
    orthogonality_penalty = OrthogonalRegularizer(num_features)

    hidden = inputs
    for filters in (32, 64, 512):
        hidden = conv_bn(hidden, filters)
    hidden = layers.GlobalMaxPooling1D()(hidden)
    for units in (256, 128):
        hidden = dense_bn(hidden, units)

    # Zero kernel + identity bias => the layer initially outputs a flattened identity.
    transform_head = layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=identity_bias,
        activity_regularizer=orthogonality_penalty,
    )
    flat_transform = transform_head(hidden)
    transform = layers.Reshape((num_features, num_features))(flat_transform)
    # Apply the predicted transform to the input features.
    return layers.Dot(axes=(2, 1))([inputs, transform])

After training, when I try to save and then load the model as follows, I get the error below:

# Save in the legacy HDF5 (.h5) format and reload it, passing the custom class explicitly.
# With OrthogonalRegularizer as defined above, the load_model call raises the TypeError below.
model.save('my_model.h5')
model = keras.models.load_model('my_model.h5', custom_objects={'OrthogonalRegularizer': OrthogonalRegularizer})

The error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-18-05f700f433a8> in <cell line: 2>()
      1 model.save('my_model.h5')
----> 2 model = keras.models.load_model('my_model.h5', custom_objects={'OrthogonalRegularizer': OrthogonalRegularizer})

2 frames
/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    260 
    261     # Legacy case.
--> 262     return legacy_sm_saving_lib.load_model(
    263         filepath, custom_objects=custom_objects, compile=compile, **kwargs
    264     )

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/keras/src/engine/base_layer.py in from_config(cls, config)
    868             return cls(**config)
    869         except Exception as e:
--> 870             raise TypeError(
    871                 f"Error when deserializing class '{cls.__name__}' using "
    872                 f"config={config}.\n\nException encountered: {e}"

TypeError: Error when deserializing class 'Dense' using config={'name': 'dense_2', 
'trainable': True, 'dtype': 'float32', 'units': 9, 'activation': 'linear', 'use_bias': 
True, 'kernel_initializer': {'module': 'keras.initializers', 'class_name': 'Zeros', 
'config': {}, 'registered_name': None}, 'bias_initializer': {'module': 
'keras.initializers', 'class_name': 'Constant', 'config': {'value': {'class_name': 
'__numpy__', 'config': {'value': [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], 'dtype': 
'float64'}}}, 'registered_name': None}, 'kernel_regularizer': None, 'bias_regularizer': 
None, 'activity_regularizer': {'module': None, 'class_name': 'OrthogonalRegularizer', 
'config': {'num_features': 3, 'l2reg': 0.001, 'eye': {'class_name': '__tensor__', 
'config': {'value': [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], 'dtype': 
'float32'}}}, 'registered_name': 'OrthogonalRegularizer>OrthogonalRegularizer'}, 
'kernel_constraint': None, 'bias_constraint': None}.

Exception encountered: object.__init__() takes exactly one argument (the instance to initialize)

My understanding so far is that the OrthogonalRegularizer object is not being saved properly. Please let me know what I'm doing wrong.

A minimal version of the code is available in this Colab notebook: https://colab.research.google.com/drive/1akpfoOBVAWThsZl7moYywuZIuXt_vWCU?usp=sharing

A possibly related question: "Load customized regularizer in Keras".


Solution

  • You do not need to call

    super(OrthogonalRegularizer, self).__init__(**kwargs)
    

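    because keras.regularizers.Regularizer does not define its own constructor, so the call falls through to object.__init__(), which accepts no keyword arguments. When the model is loaded, every key returned by get_config() is passed back to the constructor; the extra keys (l2reg and eye) end up in **kwargs, and forwarding them is what raises the TypeError above. Also, there is no need to store the tensor eye in self.eye or to include it in the config: it is fully determined by num_features, so it is better to create it where it is used.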

    The modified code should look like this:

    @keras.saving.register_keras_serializable('OrthogonalRegularizer')
    class OrthogonalRegularizer(keras.regularizers.Regularizer):
        """Activity regularizer pushing a learned feature transform toward orthogonality.

        The regularized activations are reshaped into a batch of
        ``num_features x num_features`` matrices ``A``, and the penalty
        ``l2reg * sum((A @ A^T - I) ** 2)`` is returned.

        Args:
            num_features: Size of the square transform matrix.
            l2reg: Weight of the penalty (default 0.001, the value previously
                hard-coded).
            **kwargs: Ignored on purpose, so that configs written by older
                versions of this class (e.g. ones that also contain ``eye``)
                still deserialize instead of raising a TypeError.
        """

        def __init__(self, num_features, l2reg=0.001, **kwargs):
            # Regularizer has no __init__ of its own; forwarding **kwargs to
            # super().__init__ would reach object.__init__ and fail on load.
            self.num_features = num_features
            self.l2reg = l2reg

        # Keras applies a regularizer by calling it, i.e. via __call__. The base
        # class's __call__ returns 0.0, so a method named `call` would never run
        # and the penalty would silently be zero.
        def __call__(self, x):
            x = tf.reshape(x, (-1, self.num_features, self.num_features))
            # Per-sample A @ A^T with shape (batch, n, n). tf.tensordot(x, x, axes=(2, 2))
            # would give (batch, n, batch, n) and mix rows of different samples.
            xxt = tf.matmul(x, x, transpose_b=True)
            eye = tf.eye(self.num_features, dtype=xxt.dtype)
            return tf.math.reduce_sum(self.l2reg * tf.square(xxt - eye))

        def get_config(self):
            # Only constructor arguments: from_config() rebuilds the object with cls(**config).
            return {"num_features": self.num_features, "l2reg": self.l2reg}
    

    It is also better to use a newer save format than the legacy HDF5 (.h5) file:

    # Native Keras format (.keras). A path without an extension (TF SavedModel) is
    # rejected by model.save in Keras 3. custom_objects is redundant once the class is
    # registered with register_keras_serializable, but passing it is harmless.
    model.save('my_model.keras')
    model = keras.models.load_model('my_model.keras', custom_objects={'OrthogonalRegularizer': OrthogonalRegularizer})