This question is similar to Tensorflow Keras modify model variable from callback. I am unable to get the solution there to work (maybe there have been changes in TensorFlow 2.x since the solution was posted).
Below is demo code. I apologise if there is a typo.
I want to use a callback to update a non-trainable variable (weighted_add_layer.weight) that affects the output of the layer. I have tried many variants, such as putting tf.keras.backend.set_value(weighted_add_layer.weight, value) in the update function. In all cases, after the model is compiled, fit uses the value of weighted_add_layer.weight at the time of compilation and does not update the value later.
import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
from tensorflow.keras import losses

class WeightedAddLayer(tf.keras.layers.Layer):
    def __init__(self, weight=0.00, *args, **kwargs):
        super(WeightedAddLayer, self).__init__(*args, **kwargs)
        # non-trainable variable that controls the layer output
        self.weight = tf.Variable(weight, trainable=False)

    def add(self, inputA, inputB):
        # scale both inputs by the non-trainable weight
        return self.weight * inputA + self.weight * inputB

    def update(self, weight):
        tf.keras.backend.set_value(self.weight, weight)
input_A = tfkl.Input(
    shape=(32,),
    batch_size=32,
)
input_B = tfkl.Input(
    shape=(32,),
    batch_size=32,
)
weighted_add_layer = WeightedAddLayer()
output = weighted_add_layer.add(input_A, input_B)
model = tfk.Model(
    inputs=[input_A, input_B],
    outputs=[output],
)
model.compile(
    optimizer='adam', loss=losses.MeanSquaredError()
)
# Custom callback function
def update_fun(epoch, steps=50):
    weighted_add_layer.update(
        tf.clip_by_value(
            epoch / steps,
            clip_value_min=tf.constant(0.0),
            clip_value_max=tf.constant(1.0),
        )
    )

# Custom callback
update_callback = tfk.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: update_fun(epoch)
)
# train model
history = model.fit(
    x=train_data,
    epochs=EPOCHS,
    validation_data=valid_data,
    callbacks=[update_callback],
)
Any suggestions? Thanks much!
This could be an issue with TensorFlow 2.11.0, or my installation, or something else I am missing, but with my code base the use of lambda callbacks was extremely unstable, constantly hit errors, and did not do what I wanted. It also led to odd behaviour that made it seem like there was a memory leak. The code for the complete model is very complex and I don't have time to debug it, so I am sharing this information with a big FWIW caveat.
The code in "Is there a way to make a layer behave differently during forward pass for model.fit() and model.evaluate() in a customised Keras model?" works. Some pointers:
a. The tf.Variable must sit inside a layer and be non-trainable. I could not get this approach to work with a tf.Variable defined outside a layer. That is not a big deal, as one can always define a trivial layer that only scales an input or does some simple computation and use that layer to complete the task. I found that tf.Variables outside a layer got optimised away by the compiler, so there was no way to update them after compilation.
b. assign works well as the update mechanism. I tried other approaches but ended up with assign. (A minimal sketch of these two pointers follows below.)
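Here is a minimal sketch of pointers (a) and (b), using a trivial scaling layer; the names ScaleLayer and scale are illustrative only, not from my real code:
import tensorflow as tf

# Trivial layer whose only job is to hold a non-trainable scale variable.
class ScaleLayer(tf.keras.layers.Layer):
    def __init__(self, scale=0.0, **kwargs):
        super().__init__(**kwargs)
        # The variable lives inside the layer and is non-trainable,
        # so it survives compilation and can still be updated afterwards.
        self.scale = tf.Variable(scale, trainable=False, dtype=tf.float32)

    def call(self, inputs):
        return self.scale * inputs

scale_layer = ScaleLayer()
# Later, e.g. from a callback, update the variable with assign:
scale_layer.scale.assign(0.5)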
Here is a callback subclass that is consistent with the demo code. Note that when using the class you have to pass an instance of it to fit; you cannot pass the class (or its name) itself (see the example after the code). Also note that this is not my real code but something I wrote to be consistent with the demo code above; it has not been tested and may have errors/typos.
class update_callback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None, steps=50):
        # Ramp the weight from 0 to 1 over `steps` epochs; change this to what you want
        update_value = tf.clip_by_value(
            tf.cast((epoch + 1) / steps, dtype=tf.float32),
            clip_value_min=tf.constant(0.0, dtype=tf.float32),
            clip_value_max=tf.constant(1.0, dtype=tf.float32),
        )
        weighted_add_layer.weight.assign(update_value)  # assign the update
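For example, reusing the names from the demo code above, you pass an instance of the callback (note the parentheses), not the class itself, when calling fit:
history = model.fit(
    x=train_data,
    epochs=EPOCHS,
    validation_data=valid_data,
    callbacks=[update_callback()],  # an instance, not the class name
)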