Tags: tensorflow, google-colaboratory, tensorboard

Tensorboard not updating by batch in google colab


I'm using TensorBoard in Google Colab, and it works fine if I only want to track metrics per epoch. However, I want to track accuracy/loss per batch. I'm following the getting-started documentation at https://www.tensorflow.org/tensorboard/get_started, but if I change the update_freq argument to update_freq="batch", it doesn't work. I have tried it on my local PC and there it works. Any idea what is happening?

Using tensorboard 2.8.0 and tensorflow 2.8.0

Code (running in colab)

%load_ext tensorboard
import tensorflow as tf
import datetime
!rm -rf ./logs/ 
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])
model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

log_dir = "logs/fit_2/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="batch")

model.fit(x=x_train, 
          y=y_train, 
          epochs=5, 
          validation_data=(x_test, y_test), 
          callbacks=[tensorboard_callback])

I've also tried using an integer, and that doesn't work either. On my local computer I have no problems.


Solution

  • The change after TensorFlow 2.3 made the batch-level summaries part of the Model.train_function rather than something that the TensorBoard callback creates itself. This resulted in a 2x improvement in speed for many small models in Model.fit, but it does have the side effect that calling TensorBoard.on_train_batch_end(my_batch, my_metrics) in a custom training loop will no longer log batch-level metrics.

    This issue was discussed in one of the GitHub issues.

    A workaround is to log the batch-level metrics yourself from a custom callback, for example a LambdaCallback.

    I have modified the last part of your code to explicitly write batch_loss and batch_accuracy scalars with tf.summary.scalar() so they show up in the TensorBoard logs.

    The code module is as follows:

    model = create_model()
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    from tensorflow.keras.callbacks import LambdaCallback
    
    # Write batch-level scalars explicitly. While fit() runs with the
    # TensorBoard callback attached, its summary writer is the default
    # writer, so these scalars land in the same log directory.
    # Note: step=batch restarts from 0 at every epoch.
    def batchOutput(batch, logs):
        tf.summary.scalar('batch_loss', data=logs['loss'], step=batch)
        tf.summary.scalar('batch_accuracy', data=logs['accuracy'], step=batch)
    
    batchLogCallback = LambdaCallback(on_batch_end=batchOutput)
    
    log_dir = "logs/fit_2/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq='batch')
    
    model.fit(x=x_train, 
              y=y_train, 
              epochs=1, 
              validation_data=(x_test, y_test), 
              callbacks=[tensorboard_callback, batchLogCallback])
    

    I tried this in Colab as well, and it worked.
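
    If you prefer not to rely on the TensorBoard callback's writer being active, a subclassed callback can own its own summary writer and keep a global step counter, so batch values aren't overwritten across epochs. This is only a sketch; the BatchMetricsLogger name is mine, not from the question:

    ```python
    import tensorflow as tf

    class BatchMetricsLogger(tf.keras.callbacks.Callback):
        """Logs batch-level loss/accuracy with its own summary writer,
        using a monotonically increasing step across epochs."""

        def __init__(self, log_dir):
            super().__init__()
            self.writer = tf.summary.create_file_writer(log_dir)
            self.global_step = 0

        def on_train_batch_end(self, batch, logs=None):
            logs = logs or {}
            with self.writer.as_default():
                for name in ('loss', 'accuracy'):
                    if name in logs:
                        tf.summary.scalar('batch_' + name,
                                          data=logs[name],
                                          step=self.global_step)
            self.global_step += 1
    ```

    You would then pass `BatchMetricsLogger(log_dir + '/batch')` in the callbacks list alongside (or instead of) the LambdaCallback above.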