Tags: tensorflow, keras, google-colaboratory, tensorboard

How to view train_on_batch tensorboard log files generated by Google Colab?


I know how to view TensorBoard plots on my local machine while my neural network trains, using the following code in a local Jupyter notebook. What do I need to do differently when I train the network in Google Colab instead? I can't find any tutorials or examples online that cover this for train_on_batch.

After defining my model (convnet)...

convnet.compile(loss='categorical_crossentropy',                                      
                optimizer=tf.keras.optimizers.Adam(0.001),
                metrics=['accuracy']
               )

# create tensorboard graph data for the model
tb = tf.keras.callbacks.TensorBoard(log_dir='Logs/Exp_15', 
                                    histogram_freq=0, 
                                    batch_size=batch_size, 
                                    write_graph=True, 
                                    write_grads=False)
tb.set_model(convnet)

num_epochs = 3
batches_processed_counter = 0

for epoch in range(num_epochs):

    for batch in range(int(train_img.samples/batch_size)): 
        batches_processed_counter += 1

        # get next batch of images & labels
        X_imgs, X_labels = next(train_img) 

        #train model, get cross entropy & accuracy for batch
        train_CE, train_acc = convnet.train_on_batch(X_imgs, X_labels) 

        # validation images - just predict
        X_imgs_val, X_labels_val = next(val_img)
        val_CE, val_acc = convnet.test_on_batch(X_imgs_val, X_labels_val) 

        # create tensorboard graph info for the cross entropy loss and training accuracies
        # for every batch in every epoch (so if 5 epochs and 10 batches there should be 50 accuracies)
        tb.on_epoch_end(batches_processed_counter, {'train_loss': train_CE, 'train_acc': train_acc})

        # create tensorboard graph info for the cross entropy loss and VALIDATION accuracies
        # for every batch in every epoch (so if 5 epochs and 10 batches there should be 50 accuracies)
        tb.on_epoch_end(batches_processed_counter, {'val_loss': val_CE, 'val_acc': val_acc})

        print('epoch', epoch, 'batch', batch, 'train_CE:', train_CE, 'train_acc:', train_acc)
        print('epoch', epoch, 'batch', batch, 'val_CE:', val_CE, 'val_acc:', val_acc)

tb.on_train_end(None)

I can see that the log file has been generated successfully within the Google Colab runtime. How do I view it in TensorBoard? I've seen solutions that describe downloading the log file to a local machine and viewing it in TensorBoard locally, but this doesn't display anything. Is there something missing from my code that would allow this to work in TensorBoard locally? And/or is there an alternative way to view the log data in TensorBoard within Google Colab?
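For the download route, one approach is to bundle the log directory into a single archive inside Colab, download it, and point a local TensorBoard at the extracted folder. A minimal sketch (the function name and archive name here are illustrative, not from the original post; `Logs` matches the log_dir parent used above):

```python
import os
import shutil

def archive_logs(log_dir="Logs", archive_name="tensorboard_logs"):
    """Zip the TensorBoard log directory so it can be downloaded in one file."""
    if not os.path.isdir(log_dir):
        raise FileNotFoundError(f"No log directory at {log_dir!r}")
    # Creates <archive_name>.zip containing the contents of log_dir.
    return shutil.make_archive(archive_name, "zip", log_dir)

# In Colab, download the archive to your machine (Colab-only import):
# from google.colab import files
# files.download(archive_logs())
#
# Locally, unzip it and run:
#   tensorboard --logdir ./Logs
```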

In case it's important for the details of the solution: I'm on a Mac. Also, the tutorials I've seen online show how to use TensorBoard with Google Colab when using fit, but I can't see how to adapt my code, which uses train_on_batch rather than fit.


Solution

  • Thanks to Dr Ryan Cunningham from Manchester Metropolitan University for the solution to this problem, which was the following:

    %load_ext tensorboard
    %tensorboard --logdir './Logs'
    

    ...which allows me to view the TensorBoard plots in the Google Colab document itself, and see them update while the network is training.

    So, the full set of code to view the TensorBoard plots while the network is training is (after defining the neural network, which I've called convnet):

    # compile the neural net after defining the loss, optimisation and 
    # performance metric
    convnet.compile(loss='categorical_crossentropy',  # cross entropy is suited to 
                                                       # multi-class classification
                    optimizer=tf.keras.optimizers.Adam(0.001),
                    metrics=['accuracy']
                   )
    
    # create tensorboard graph data for the model
    # (note: batch_size and write_grads are only accepted by older TF/Keras
    # versions; later TF 2.x releases removed these callback arguments)
    tb = tf.keras.callbacks.TensorBoard(log_dir='Logs/Exp_15',
                                        histogram_freq=0,
                                        batch_size=batch_size,
                                        write_graph=True,
                                        write_grads=False)
    tb.set_model(convnet)
    
    %load_ext tensorboard
    %tensorboard --logdir './Logs'
    
    # iterate through the training set for x epochs, 
    # each time iterating through the batches,
    # for each batch, train, calculate loss & optimise weights. 
    # (mini-batch approach)
    num_epochs = 1
    batches_processed_counter = 0
    
    for epoch in range(num_epochs):
    
        for batch in range(int(train_img.samples/batch_size)): 
            batches_processed_counter += 1
    
            # get next batch of images & labels
            X_imgs, X_labels = next(train_img) 
    
            #train model, get cross entropy & accuracy for batch
            train_CE, train_acc = convnet.train_on_batch(X_imgs, X_labels) 
    
            # validation images - just predict
            X_imgs_val, X_labels_val = next(val_img)
            val_CE, val_acc = convnet.test_on_batch(X_imgs_val, X_labels_val) 
    
            # create tensorboard graph info for the cross entropy loss and training accuracies
            # for every batch in every epoch (so if 5 epochs and 10 batches there should be 50 accuracies)
            tb.on_epoch_end(batches_processed_counter, {'train_loss': train_CE, 'train_acc': train_acc})
    
            # create tensorboard graph info for the cross entropy loss and VALIDATION accuracies
            # for every batch in every epoch (so if 5 epochs and 10 batches there should be 50 accuracies)
            tb.on_epoch_end(batches_processed_counter, {'val_loss': val_CE, 'val_acc': val_acc})
    
            print('epoch', epoch, 'batch', batch, 'train_CE:', train_CE, 'train_acc:', train_acc)
            print('epoch', epoch, 'batch', batch, 'val_CE:', val_CE, 'val_acc:', val_acc)
    
    tb.on_train_end(None)
    
    
    

    Note: it can take a few seconds after the cell has finished running before the cell output refreshes and shows the TensorBoard plots.
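    As an aside: in more recent TensorFlow 2 releases the TensorBoard callback no longer accepts batch_size or write_grads, and the idiomatic way to log per-batch scalars from a custom train_on_batch loop is a tf.summary file writer. A sketch of that alternative (not the answer's original code; directory names and the helper function are illustrative):

```python
import tensorflow as tf

# Separate writers so the train and validation curves overlay in TensorBoard.
train_writer = tf.summary.create_file_writer("Logs/Exp_15/train")
val_writer = tf.summary.create_file_writer("Logs/Exp_15/val")

def log_batch(step, train_ce, train_acc, val_ce, val_acc):
    """Write one batch's metrics; step is the global batch counter."""
    with train_writer.as_default():
        tf.summary.scalar("loss", train_ce, step=step)
        tf.summary.scalar("accuracy", train_acc, step=step)
    with val_writer.as_default():
        tf.summary.scalar("loss", val_ce, step=step)
        tf.summary.scalar("accuracy", val_acc, step=step)

# Inside the training loop, after train_on_batch / test_on_batch:
# log_batch(batches_processed_counter, train_CE, train_acc, val_CE, val_acc)
```

    The %tensorboard --logdir './Logs' magic works the same way with this approach, since it only watches the directory for event files.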