Tags: python, tensorflow, keras

Update a non-trainable variable at the beginning of each epoch


This question is similar to Tensorflow Keras modify model variable from callback. I am unable to get the solution there to work (maybe there have been changes in TensorFlow 2.x since the solution was posted).

Below is demo code. I apologise if there is a typo.

I want to use a callback to update a non-trainable variable (weighted_add_layer.weight) that affects the output of the layer.

I have tried many variants, such as putting tf.keras.backend.set_value(weighted_add_layer.weight, value) in the update function.

In all cases, after the model is compiled, fit uses the value of weighted_add_layer.weight at the time of compilation and does not update the value later.

import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow.keras.layers as tfkl
from tensorflow.keras import losses

class WeightedAddLayer(tf.keras.layers.Layer):
    def __init__(self, weight=0.00, *args, **kwargs):
        super(WeightedAddLayer, self).__init__(*args, **kwargs)
        # Non-trainable variable initialised from the constructor argument
        self.weight = tf.Variable(weight, trainable=False)

    def add(self, inputA, inputB):
        return (self.weight * inputA + self.weight * inputB)

    def update(self, weight):
        tf.keras.backend.set_value(self.weight, weight)
        
input_A = tfkl.Input(
    shape=(32,),
    batch_size=32,
)

input_B = tfkl.Input(
    shape=(32,),
    batch_size=32,
)

weighted_add_layer = WeightedAddLayer()

output = weighted_add_layer.add(input_A, input_B)

model = tfk.Model(
    inputs=[input_A, input_B],
    outputs=[output],
)
model.compile(
    optimizer='adam', loss=losses.MeanSquaredError()
)

# Custom callback function
def update_fun(epoch, steps=50):
    weighted_add_layer.update(
        tf.clip_by_value(
            epoch / steps,
            clip_value_min=tf.constant(0.0),
            clip_value_max=tf.constant(1.0),
        )
    )
    

# Custom callback
update_callback = tfk.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: update_fun(epoch)
)

# train model
history = model.fit(
    x=train_data,
    epochs=EPOCHS,
    validation_data=valid_data,
    callbacks=[update_callback],
)

Any suggestions? Thanks much!


Solution

    1. This could be an issue with TensorFlow 2.11.0, my installation, or something else I am missing, but lambda callbacks were extremely unstable with my code base: they raised errors constantly and did not do what I wanted. They also caused odd behaviour that looked like a memory leak. The code for the complete model is very complex and I do not have the time to debug it, so I am sharing this information with a big FWIW caveat.

    2. The code in "Is there a way to make a layer behave differently during forward pass for model.fit() and model.evaluate() in a customised Keras model?" works. Some pointers:

    a. The tf.Variable must sit inside a layer and be non-trainable. I could not get this approach to work with a tf.Variable defined outside a layer. That is not a big deal, as one can always define a trivial layer that only scales an input or does some simple computation and use that layer to complete the task (see the sketch after these pointers). I found that tf.Variables outside a layer were optimised away by the compiler, so there was no way to update them after compilation.

    b. assign() works well as the update mechanism. I tried other approaches but ended up using assign().
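
    To make both pointers concrete, here is a minimal sketch (my own illustration, not code from the linked answer) of such a trivial layer: it wraps a non-trainable tf.Variable, uses it in call, and can be updated after compilation with assign. The ScaleLayer name and scale attribute are made up for the example.

    class ScaleLayer(tf.keras.layers.Layer):
        def __init__(self, scale=1.0, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Non-trainable, so the optimiser ignores it, but because it
            # lives inside a layer it is not optimised away at compilation
            self.scale = tf.Variable(scale, trainable=False, dtype=tf.float32)

        def call(self, inputs):
            return self.scale * inputs

    scale_layer = ScaleLayer()
    scale_layer.scale.assign(0.5)  # update the variable in place at any time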

    Here is a callback subclass that is consistent with the demo code. Note that when using the class you must pass an instance of it to fit; you cannot pass the class itself (see the usage sketch after the code). Also note that this is not my real code but something I wrote to be consistent with the demo code above; it has not been tested and may have errors/typos.

    class update_callback(tf.keras.callbacks.Callback):
        def __init__(self, steps=50):
            super().__init__()
            self.steps = steps

        def on_epoch_begin(self, epoch, logs=None):
            # Ramp the weight from 0 towards 1 over self.steps epochs
            update_value = tf.clip_by_value(
                tf.cast((epoch + 1) / self.steps, dtype=tf.float32),
                clip_value_min=tf.constant(0.0, dtype=tf.float32),
                clip_value_max=tf.constant(1.0, dtype=tf.float32),
            )  # change this to whatever schedule you want
            weighted_add_layer.weight.assign(update_value)  # assign the update
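
    For completeness, here is a hedged sketch of how the callback instance would be passed to fit, reusing the train_data, valid_data, and EPOCHS placeholders from the question:

    history = model.fit(
        x=train_data,
        epochs=EPOCHS,
        validation_data=valid_data,
        callbacks=[update_callback(steps=50)],  # an instance, not the class
    )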