Tags: tensorflow, machine-learning, keras, tensorflow-2.0, tensorboard

Resuming training with fit resets batch step to 0


I have recently rewritten my TensorFlow 2 custom training loop to use fit, with the ModelCheckpoint callback managing the checkpointing I was previously doing manually in the loop. This is all working nicely, but there is one issue I've been struggling with: resuming training at the correct batch step. I use the TensorBoard callback with update_freq=50, and when I resume training (loading the weights from a saved checkpoint) I typically see things like the following:
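For context, here is a minimal sketch of that kind of setup. The model, dataset, and file paths are placeholders standing in for my actual code, not the originals:

```python
import os
import numpy as np
import tensorflow as tf

# Toy stand-ins for the real model and dataset (names are illustrative).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(1000, 4).astype("float32")
y = np.random.rand(1000, 1).astype("float32")
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)  # 250 steps/epoch

os.makedirs("checkpoints", exist_ok=True)
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/ckpt-{epoch:02d}.h5",  # assumed checkpoint layout
    save_weights_only=True,
)
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir="logs",
    update_freq=50,  # write batch-level summaries every 50 steps
)

# First run: 2 epochs of 250 steps each.
model.fit(train_ds, epochs=2, callbacks=[checkpoint_cb, tensorboard_cb])
```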

[TensorBoard screenshot: batch-level loss curves from two runs; the second run's step count restarts at 0]

The above shows the results of 2 runs, with a single epoch containing 250 batch steps (just a toy dataset). The top line is the first run of 2 epochs, ending after 500 steps (the summaries update every 50 steps, but apparently not after the final step, so the point at 500 is missing). The straight line in the middle is just the graph being drawn to the start of the second run, which is the bottom line. I ran that one for 3 epochs, hence 750 batch steps (250 * 3).

The problem is that the step count starts again from 0 at each restart of training. How can I fix this? Presumably the TensorBoard callback is tracking the steps in each epoch from 0. The fit method has an initial_epoch parameter, which I use to restart at the correct epoch, but is it possible to track the global batch step? I've seen global_step in older (pre-TF2) code; was that used to implement this?
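Resuming currently looks roughly like this, continuing the placeholder names from the sketch above (the checkpoint filename is assumed):

```python
# Restore weights from the last checkpoint of the first run, then resume.
model.load_weights("checkpoints/ckpt-02.h5")

model.fit(
    train_ds,
    epochs=5,          # train for 3 more epochs (3, 4, 5)
    initial_epoch=2,   # the epoch counter resumes correctly...
    callbacks=[checkpoint_cb, tensorboard_cb],
)
# ...but the TensorBoard batch step still restarts from 0, as shown above.
```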


Solution

  • OK, it seems the way to manage this is through the concept of runs, which amounts to creating a separate log subdirectory for each training session, as demonstrated in this Colab notebook; see the sketch below.
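A minimal sketch of that approach, assuming a timestamped subdirectory name (the naming format is arbitrary):

```python
import datetime
import tensorflow as tf

# Give each training session its own log subdirectory, so TensorBoard
# overlays the sessions as separate runs instead of drawing one run
# whose step counter appears to reset to 0.
run_name = datetime.datetime.now().strftime("run-%Y%m%d-%H%M%S")
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{run_name}",
    update_freq=50,
)
```

With this, the resumed session still counts its batch steps from 0, but it shows up in TensorBoard as its own run rather than overlapping the earlier one.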