python tensorflow keras deep-learning callback

How to print one log line every 10 epochs when training models with TensorFlow Keras?


When I fit the model with:

model.fit(X, y, epochs=40, batch_size=32, validation_split=0.2, verbose=2)

it prints one log line for each epoch:

Epoch 1/100
0s - loss: 0.2506 - acc: 0.5750 - val_loss: 0.2501 - val_acc: 0.3750
Epoch 2/100
0s - loss: 0.2487 - acc: 0.6250 - val_loss: 0.2498 - val_acc: 0.6250
Epoch 3/100
0s - loss: 0.2495 - acc: 0.5750 - val_loss: 0.2496 - val_acc: 0.6250
.....

How can I print a log line only every 10 epochs, as follows?

Epoch 10/100
0s - loss: 0.2506 - acc: 0.5750 - val_loss: 0.2501 - val_acc: 0.3750
Epoch 20/100
0s - loss: 0.2487 - acc: 0.6250 - val_loss: 0.2498 - val_acc: 0.6250
Epoch 30/100
0s - loss: 0.2495 - acc: 0.5750 - val_loss: 0.2496 - val_acc: 0.6250
.....


Solution

  • This callback creates a text log file and appends one line to it every 10 epochs:

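    from contextlib import redirect_stdout

    from tensorflow.keras.callbacks import Callback

    log_path = "text_file_name.txt"  # created automatically on the first write


    class print_training_on_text_every_10_epochs_Callback(Callback):
        def __init__(self, logpath):
            super().__init__()  # initialise the Keras Callback base class
            self.logpath = logpath

        def on_epoch_end(self, epoch, logs=None):
            # `epoch` is zero-based, so this writes at epochs 0, 10, 20, ...
            if epoch % 10 != 0:
                return
            # The accuracy keys exist only if the model was compiled with
            # metrics=["accuracy"]; the val_* keys need validation data.
            with open(self.logpath, "a") as writefile:
                with redirect_stdout(writefile):
                    print(
                        f"Epoch: {epoch:>3}"
                        + f" | Loss: {logs['loss']:.4e}"
                        + f" | Accuracy: {logs['accuracy']:.4e}"
                        + f" | Validation loss: {logs['val_loss']:.4e}"
                        + f" | Validation accuracy: {logs['val_accuracy']:.4e}"
                    )
                    writefile.write("\n")  # blank line between entries


    my_callbacks = [
        print_training_on_text_every_10_epochs_Callback(logpath=log_path),
    ]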
    

    Pass the callback to model.fit like this:

    model.fit(
        training_dataset,
        epochs=60,
        verbose=0,  # Must be 0; otherwise Keras still prints its own log for every epoch.
        validation_data=validation_dataset,
        callbacks=my_callbacks,
    )
    

    A line is appended to the text file only on every 10th epoch (epochs 0, 10, 20, ...); nothing is written in between.

    This is what I get in the text file (this run logged only the two losses):

    Epoch:   0 | Loss: 5.3454e+00 | Valid loss: 4.2420e-01
    
    Epoch:  10 | Loss: 3.1342e-02 | Valid loss: 3.4554e-02
    
    Epoch:  20 | Loss: 1.6330e-02 | Valid loss: 2.2512e-02
    

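    Keras passes a zero-based epoch index to on_epoch_end, so the first epoch is numbered 0, the second 1, and so on. To log after the 10th, 20th, ... epoch instead, as the question asks, change the test to (epoch + 1) % 10 == 0 and print epoch + 1.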
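
If you would rather print to the console in the format shown in the question, a variant along these lines might work. It is only a sketch: the class name `PrintEveryNEpochs`, the helper `format_epoch_log`, and the `every` and `total_epochs` arguments are made up for illustration and are not part of Keras. The `try`/`except` around the import is there only so the formatting logic can be tried without TensorFlow installed.

```python
try:
    from tensorflow.keras.callbacks import Callback
except ImportError:
    # Fallback (assumption for this sketch): lets the code run without TensorFlow.
    Callback = object


def format_epoch_log(epoch, total_epochs, logs):
    """Build an 'Epoch k/n' header plus one metrics line for a zero-based epoch."""
    metrics = " - ".join(f"{name}: {value:.4f}" for name, value in logs.items())
    return f"Epoch {epoch + 1}/{total_epochs}\n{metrics}"


class PrintEveryNEpochs(Callback):  # hypothetical name, not a Keras class
    def __init__(self, every=10, total_epochs=None):
        super().__init__()
        self.every = every                # print on every `every`-th epoch
        self.total_epochs = total_epochs  # only used for the 'k/n' header

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a zero-based index; epoch + 1 is the one-based number.
        if (epoch + 1) % self.every == 0:
            print(format_epoch_log(epoch, self.total_epochs, logs or {}))
```

You would then call something like `model.fit(X, y, epochs=100, verbose=0, callbacks=[PrintEveryNEpochs(every=10, total_epochs=100)])`. Because it loops over whatever keys are in `logs`, it does not depend on the metric being named `acc` or `accuracy`.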