
Customized tf.keras.callbacks.TensorBoard does not work well in TensorFlow version >= 1.15.0


The following customized TensorBoard callback, cloned from this GitHub repo, is meant to store the learning rate and the so-called KL weight at the end of every batch. It works very well in TensorFlow version <= 1.12.0, but in version >= 1.15.0 it no longer does what it should at every batch. How can I fix it?

from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import TensorBoard


class TensorBoardLR(TensorBoard):
    """TensorBoard callback that also logs the learning rate and KL weight as scalars."""
    def __init__(self, *args, **kwargs):
        # 'kl_weight' is a custom argument, so remove it before calling the parent constructor
        self.kl_weight = kwargs.pop('kl_weight')
        super().__init__(*args, **kwargs)
        self.count = 0

    def on_batch_end(self, batch, logs=None):
        logs = logs if logs is not None else {}  # logs may be None
        logs.update({'lr': K.eval(self.model.optimizer.lr),
                     'kl_weight': K.eval(self.kl_weight)})
        super().on_batch_end(batch, logs)

Solution

  • Change the name of your method from on_batch_end to on_train_batch_end.

    tf.keras kept on_batch_end only for legacy code; see https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras/callbacks/Callback. This is also one of the differences between standalone keras and tensorflow.keras.
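
For reference, here is a minimal sketch of the callback from the question with only the hook renamed as suggested above. It assumes, as in the question, that kl_weight is a backend variable that K.eval can read and that the optimizer exposes its learning rate as lr. The unused count attribute is dropped and a guard for logs=None is added; neither is part of the fix itself.

```python
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import TensorBoard


class TensorBoardLR(TensorBoard):
    """TensorBoard callback that also logs the learning rate and KL weight."""

    def __init__(self, *args, **kwargs):
        # 'kl_weight' is a custom argument, so remove it before
        # forwarding the remaining arguments to TensorBoard.
        self.kl_weight = kwargs.pop('kl_weight')
        super().__init__(*args, **kwargs)

    def on_train_batch_end(self, batch, logs=None):
        # Renamed from on_batch_end, as suggested in the answer.
        logs = logs if logs is not None else {}
        logs.update({'lr': K.eval(self.model.optimizer.lr),
                     'kl_weight': K.eval(self.kl_weight)})
        super().on_train_batch_end(batch, logs)
```

The callback is passed to model.fit through the callbacks list, as before. Depending on the version, the base TensorBoard callback may only write per-batch scalars when it is constructed with update_freq='batch', so that is worth checking against the API docs linked above.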