Tags: python, tensorflow, keras, layer, eager-execution

How to update a parameter at each epoch within an intermediate layer between training runs? (TensorFlow eager execution)


I have a sequential Keras model that contains a custom layer similar to the following example, named 'CounterLayer'. I am using TensorFlow 2.0 (eager execution).

import tensorflow as tf

class CounterLayer(tf.keras.layers.Layer):
  def __init__(self, stateful=False,**kwargs):
    self.stateful = stateful
    super(CounterLayer, self).__init__(**kwargs)


  def build(self, input_shape):
    self.count = tf.keras.backend.variable(0, name="count")
    super(CounterLayer, self).build(input_shape)

  def call(self, input):
    updates = []
    updates.append((self.count, self.count+1))
    self.add_update(updates)
    tf.print('-------------')
    tf.print(self.count)
    return input

When I train this for, say, 5 epochs, the value of self.count is not updated on each run; it always stays the same. I got this example from https://stackoverflow.com/a/41710515/10645817. I need something very similar, but does this approach work under TensorFlow's eager execution, and if not, what do I have to do to get the expected output?

I have been trying to implement this for quite a while but could not figure it out. Can somebody help me, please? Thank you.


Solution

  • Yes, my issue is resolved. I came across built-in methods for updating this kind of variable (one that has to keep persistent state between epochs, as in my case above). Basically, what I needed to do was, for example:

      def build(self, input_shape):
        self.count = tf.Variable(0, dtype=tf.float32, trainable=False)
        super(CounterLayer, self).build(input_shape)
    
      def call(self, input):
        ............
        self.count.assign_add(1)
        ............
        return input
    

    The updated value can also be computed in the call function and assigned with self.count.assign(some_updated_value). Details on these operations are available at https://www.tensorflow.org/api_docs/python/tf/Variable. Thanks.
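    For reference, the fragments above can be put together into a small self-contained sketch. It is only an illustration: it assumes TensorFlow 2.x, and it calls the layer directly in eager mode instead of running it inside a full Sequential model. The names x, y, calls_after_loop and calls_after_reset are made up for the example.

```python
import tensorflow as tf


class CounterLayer(tf.keras.layers.Layer):
    """Identity layer that counts how many times call() has run."""

    def build(self, input_shape):
        # Non-trainable state: it persists between calls (and therefore
        # between batches and epochs) and is ignored by the optimizer.
        self.count = tf.Variable(0.0, dtype=tf.float32, trainable=False)
        super().build(input_shape)

    def call(self, inputs):
        # In-place update of the variable; no add_update() is needed.
        self.count.assign_add(1.0)
        return inputs


layer = CounterLayer()
x = tf.ones((2, 3))  # hypothetical input batch

# Call the layer eagerly a few times; the counter grows by one per call.
for _ in range(5):
    y = layer(x)
calls_after_loop = int(layer.count.numpy())

# assign() overwrites the stored value, here resetting the counter.
layer.count.assign(0.0)
calls_after_reset = int(layer.count.numpy())

print(calls_after_loop, calls_after_reset)
```

    Note that during training the counter in this sketch advances once per call of the layer (that is, per batch), not once per epoch; if a per-epoch value is needed, it would have to be derived from or reset around that count.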