Search code examples
tensorflowkeras

I can't save the model when using Custom Layer


I need to use an attention layer that returns a 3D output, so I built the layer below and used it in my model. It worked fine, but when I tried to save the model, it failed:

# NOTE(review): this class lists both Model and Layer as bases. The traceback
# shown further down is raised from super().get_config() in this class, and the
# reported fix is to inherit from Layer only.
class attention(Model,Layer):
    """Attention layer that scales its input by softmax-normalised scores.

    Scores are computed as tanh(x . W + b) and normalised with a softmax
    over axis 1. Depending on ``return_sequences`` the weighted input is
    returned as is, or summed over axis 1.
    """

    def __init__(self, return_sequences=True,**kwargs):
        """Store the output mode and forward ``kwargs`` to the parent constructor.

        Args:
            return_sequences: if true, ``call`` returns the weighted input
                itself; otherwise it returns the weighted input summed over
                axis 1.
            **kwargs: passed through unchanged to the parent ``__init__``.
        """
        super(attention, self).__init__(**kwargs)
        self.return_sequences = return_sequences
  
    def build(self, input_shape):
        """Create the layer weights from ``input_shape``.

        ``W`` has shape ``(input_shape[-1], 1)`` and ``b`` has shape
        ``(input_shape[1], 1)``; both use the "normal" initializer.
        """
        self.W=self.add_weight(name="att_weight", shape=(input_shape[-1],1),
                               initializer="normal")
        # The bias is sized by input_shape[1], so it is tied to the length of
        # axis 1 of the input seen at build time.
        self.b=self.add_weight(name="att_bias", shape=(input_shape[1],1),
                               initializer="normal")
        super(attention,self).build(input_shape)

    def call(self, x):
        """Apply the attention weighting to ``x``.

        Returns ``x * a`` when ``return_sequences`` is true, otherwise
        ``x * a`` summed over axis 1.
        """
        # assumes x is (batch, steps, features) -- TODO confirm against caller
        # Raw scores: projection of x onto W plus bias, squashed by tanh.
        e = K.tanh(K.dot(x,self.W)+self.b)
        # Normalise the scores over axis 1.
        a = K.softmax(e, axis=1)
        # Scale the input by its attention weights.
        output = x*a
        if self.return_sequences:
            return output
        return K.sum(output, axis=1)

    def get_config(self):
        """Return a copy of the parent config extended with ``return_sequences``."""
        config = super(attention, self).get_config().copy()
        config.update({"return_sequences": self.return_sequences})
        return config

When I try to save my model I get the following error:

tf.keras.models.save_model(model,filepath+'/my_h5_model.h5',save_traces=False)
---------------------------------------------------------------------------

NotImplementedError                       Traceback (most recent call last)

<ipython-input-7-53e98aa74c0b> in <module>()
      1 filepath='/content/drive/MyDrive/Colab Notebooks/AE/models'
----> 2 tf.keras.models.save_model(model,filepath+'/my_h5_model.h5',save_traces=False)
      3 #model.save(filepath+'/my_h5_model.h5',save_traces=False)

1 frames

/content/drive/MyDrive/Colab Notebooks/AE/layer.py in get_config(self)
     32 
     33     def get_config(self):
---> 34         config = super(attention, self).get_config().copy()
     35         config.update({"return_sequences": self.return_sequences,'name':self.name})
     36         return config

NotImplementedError: 

Solution

  • I ran the above code after removing `Model` from the class's base classes, so that the class inherits only from `tf.keras.layers.Layer`, and did not face any error while saving the model:

    class attention(tf.keras.layers.Layer):
        """Attention layer that scales its input by softmax-normalised scores.

        Scores are ``tanh(x . W + b)``, normalised with a softmax over axis 1.
        With ``return_sequences=True`` the weighted input is returned as is;
        otherwise it is summed over axis 1.
        """

        def __init__(self, return_sequences=True, **kwargs):
            """Forward ``kwargs`` to the base layer and remember the output mode."""
            super().__init__(**kwargs)
            self.return_sequences = return_sequences

        def build(self, input_shape):
            """Create ``W`` (input_shape[-1], 1) and ``b`` (input_shape[1], 1)."""
            feature_dim = input_shape[-1]
            step_dim = input_shape[1]
            # Weights are created in the same order (W, then b) so that the
            # layer's weight list keeps its layout.
            self.W = self.add_weight(
                name="att_weight",
                shape=(feature_dim, 1),
                initializer="normal",
            )
            self.b = self.add_weight(
                name="att_bias",
                shape=(step_dim, 1),
                initializer="normal",
            )
            super().build(input_shape)

        def call(self, x):
            """Return the attention-weighted input, optionally summed over axis 1."""
            scores = K.tanh(K.dot(x, self.W) + self.b)
            weights = K.softmax(scores, axis=1)
            weighted = x * weights
            if not self.return_sequences:
                return K.sum(weighted, axis=1)
            return weighted

        def get_config(self):
            """Return the base layer config plus this layer's ``return_sequences``."""
            base_config = super().get_config()
            return {**base_config, "return_sequences": self.return_sequences}
    

    Please refer to this gist for a working code example. Thank you.