Tags: python, tensorflow, save, keras-layer

NotImplementedError: Layer attention has arguments in `__init__` and therefore must override `get_config`


I have implemented a custom attention layer as suggested in this link: How to add attention layer to a Bi-LSTM

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class attention(Layer):

    def __init__(self, return_sequences=True, **kwargs):
        self.return_sequences = return_sequences
        # Forward standard Layer kwargs (e.g. name, dtype) to the base class
        super(attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # One weight per input feature and one bias per timestep
        self.W = self.add_weight(name="att_weight", shape=(input_shape[-1], 1),
                                 initializer="normal")
        self.b = self.add_weight(name="att_bias", shape=(input_shape[1], 1),
                                 initializer="zeros")
        super(attention, self).build(input_shape)

    def call(self, x):
        # Score each timestep, normalize the scores over the time axis,
        # then reweight the inputs by the attention weights
        e = K.tanh(K.dot(x, self.W) + self.b)
        a = K.softmax(e, axis=1)
        output = x * a

        if self.return_sequences:
            return output

        # Collapse the time axis into a single context vector
        return K.sum(output, axis=1)

The code ran, but I got this error when the model was saved:

    NotImplementedError: Layer attention has arguments in `__init__` and therefore must override `get_config`.
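For context, this is the kind of code that triggers the error at save time (a minimal sketch; the model architecture, shapes, and file name are my assumptions, not from the original post):

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Bidirectional, LSTM, Dense

# Assumes the `attention` class defined above is in scope
model = Sequential([
    Bidirectional(LSTM(32, return_sequences=True), input_shape=(10, 8)),
    attention(return_sequences=False),
    Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Saving to HDF5 raises NotImplementedError: the base Layer.get_config
# cannot serialize the custom `return_sequences` argument
model.save("model.h5")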

Some answers suggest overriding `get_config`.

"This error lets you know that tensorflow can't save your model, because it won't be able to load it. Specifically, it won't be able to reinstantiate your custom Layer classes.

To solve this, just override their get_config method according to the new arguments you've added."

Link to that answer: NotImplementedError: Layers with arguments in `__init__` must override `get_config`

My question is: based on the custom attention layer above, how do I write `get_config` to resolve this error?


Solution

  • You need a `get_config` method like this:

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'return_sequences': self.return_sequences 
        })
        return config
    

    All the info needed was in the other post that you linked; a save-and-reload sketch follows below.
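With `get_config` overridden (and `__init__` forwarding extra kwargs to the base `Layer`, as in the class above, so that reinstantiation from the saved config works), saving succeeds and the model can be reloaded by passing the class via `custom_objects`. A minimal sketch; the file name is an assumption:

    from tensorflow.keras.models import load_model

    model.save("model.h5")  # no longer raises NotImplementedError

    # Keras needs to know how to reinstantiate the custom layer on load
    reloaded = load_model("model.h5", custom_objects={"attention": attention})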