Search code examples
Tags: python, keras, tensorflow2.0

Keras Custom Layer gives errors when saving the full model


class ConstLayer(tf.keras.layers.Layer):
    """Layer that ignores its input and always returns the stored constant `x`.

    NOTE(review): this is the failing reproduction; the problematic line is
    the `config['x'] = self.x` assignment in `get_config` below.
    """
    def __init__(self, x, **kwargs):
        super(ConstLayer, self).__init__(**kwargs)
        # Non-trainable variable initialised from the constructor argument `x`.
        self.x = tf.Variable(x, trainable=False)

    def call(self, input):
        # The input is ignored; the stored variable is returned unchanged.
        return self.x

    def get_config(self):
        # Note: the original (larger) model runs with eager execution disabled.
        config = super(ConstLayer, self).get_config()
        # This stores the tf.Variable object itself, not the plain value that
        # was passed to __init__. With eager execution disabled, that is what
        # leads to "NotImplementedError: deepcopy() is only available when
        # eager execution is enabled". Leaving the key out instead makes
        # load_model fail, because __init__ requires `x`.
        config['x'] = self.x
        return config
    


# Minimal reproduction: build a Sequential model that contains ConstLayer,
# save it to HDF5, then reload it. `custom_objects` tells load_model which
# Python class to use for the saved 'ConstLayer' entry.
# NOTE(review): assumes the ../models/ directory already exists.
model_test_const_layer = keras.Sequential([
    # `(784,)` is a one-element tuple; `(784)` would just be the int 784.
    keras.Input(shape=(784,)),
    ConstLayer([[1.,1.]], name="anchors"),
    keras.layers.Dense(10),
])

model_test_const_layer.summary()
model_test_const_layer.save("../models/my_model_test_constlayer.h5")
del model_test_const_layer
model_test_const_layer = keras.models.load_model("../models/my_model_test_constlayer.h5",custom_objects={'ConstLayer': ConstLayer,})
model_test_const_layer.summary()

This code is a minimal (sandbox) reproduction of an error raised by a larger Keras model with a ResNet-101 backbone.

Errors: If the model includes the custom layer ConstLayer:

  • Without the line `config['x'] = self.x`, loading the saved model with `keras.models.load_model` fails with: `TypeError: __init__() missing 1 required positional argument: 'x'`

  • With the line `config['x'] = self.x`, the following error is raised: `NotImplementedError: deepcopy() is only available when eager execution is enabled`. Note: the larger model requires eager execution to be disabled via `tf.compat.v1.disable_eager_execution()`.

Any help and clues are greatly appreciated!


Solution

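  • As far as I understand it, TF has problems copying variables. With `config['x'] = self.x`, the layer config contains the `tf.Variable` object itself, and deep-copying it fails when eager execution is disabled (the `NotImplementedError` above). Instead, save the original value passed to the layer and return that from `get_config()`: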

    import tensorflow as tf
    import tensorflow.keras as keras
    
    # The larger model this reproduces requires graph mode, so eager execution
    # is switched off here, before any layers or variables are created.
    tf.compat.v1.disable_eager_execution()
    
    class ConstLayer(tf.keras.layers.Layer):
        """Layer that ignores its input and always returns a constant.

        Args:
            x: The constant value (e.g. a nested list of floats). It is kept
                as-is so that ``get_config`` can return it, and it is also
                used to initialise a non-trainable ``tf.Variable``.
            **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
                (e.g. ``name``).
        """

        def __init__(self, x, **kwargs):
            super().__init__(**kwargs)
            # Keep the raw constructor argument for get_config(). Returning the
            # tf.Variable there instead makes saving try to deep-copy it, which
            # fails when eager execution is disabled.
            self._config = {'x': x}
            self.x = tf.Variable(x, trainable=False)

        def call(self, inputs):
            # The input tensor is ignored; the stored constant is returned.
            return self.x

        def get_config(self):
            """Return the base layer config plus the original ``x`` value.

            When the model is reloaded, this dict is used to re-create the
            layer, so ``x`` must be present and must be the plain value that
            was passed to ``__init__``, not the variable.
            """
            return {**super().get_config(), **self._config}
    
    
    # Build a model containing ConstLayer, save it to HDF5, then reload it.
    # `custom_objects` tells load_model which Python class to use for the
    # saved 'ConstLayer' entry.
    # NOTE(review): assumes the ../models/ directory already exists.
    model_test_const_layer = keras.Sequential([
        # `(784,)` is a one-element tuple; `(784)` would just be the int 784.
        keras.Input(shape=(784,)),
        ConstLayer([[1., 1.]], name="anchors"),
        keras.layers.Dense(10),
    ])

    model_test_const_layer.summary()
    model_test_const_layer.save("../models/my_model_test_constlayer.h5")
    del model_test_const_layer
    model_test_const_layer = keras.models.load_model(
        "../models/my_model_test_constlayer.h5", custom_objects={'ConstLayer': ConstLayer, })
    model_test_const_layer.summary()