
Log Keras metrics for each batch, as in the Keras example for the loss


The Keras documentation has an example that creates a custom callback to log the loss for each batch. This has worked fine for me; however, I also want to log the metrics I add.

For example, with this code:

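from tensorflow.keras import losses
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam

optimizer = Adam()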
loss = losses.categorical_crossentropy
metric = ["accuracy"]

model.compile(optimizer=optimizer,
              loss=loss,
              metrics=metric)


class LossHistory(Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

loss_history = LossHistory()

history = model.fit(training_data, training_labels,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=2,
                    validation_data=(val_data, val_labels),
                    callbacks=[loss_history])

I can't figure out how to get access to the metrics.


Solution

  • The loss history is stored in loss_history.losses, which is filled by this method:

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
    

    Keras calls this method at the end of every training batch, and it appends that batch's loss value to self.losses. Once training has finished, you can read the list directly as loss_history.losses.

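    The same logs dictionary holds one entry per compiled metric, keyed by the metric's name. With metrics=["accuracy"], tf.keras uses the key 'accuracy' (older Keras versions used 'acc'). So to record accuracy as well, you can do something like: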

    class LossHistory(Callback):
        def on_train_begin(self, logs={}):
            self.losses = []
            self.accuracy = []
    
        def on_batch_end(self, batch, logs={}):
            self.losses.append(logs.get('loss'))
            self.accuracy.append(logs.get('accuracy'))
    

    and then access it after training with:

    loss_history.accuracy
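Rather than adding one attribute per metric, a callback can record every key Keras puts into logs. That picks up the loss, 'accuracy' (or 'acc'), and any custom metric without hard-coding names. Below is a minimal sketch of that idea; BatchHistory and its per_batch attribute are illustrative names, not part of Keras. The ImportError fallback exists only so the class can be exercised without TensorFlow installed.

```python
from collections import defaultdict

try:
    from tensorflow.keras.callbacks import Callback
except ImportError:
    # Fallback only so this sketch can be exercised without TensorFlow;
    # in real training code, Callback comes from tensorflow.keras.
    Callback = object


class BatchHistory(Callback):
    """Record every value Keras puts in `logs` at the end of each training batch."""

    def on_train_begin(self, logs=None):
        # One list per log key, e.g. 'loss', 'accuracy', or a custom metric name.
        self.per_batch = defaultdict(list)

    def on_batch_end(self, batch, logs=None):
        # Append each logged value to the list for its key.
        for key, value in (logs or {}).items():
            self.per_batch[key].append(float(value))
```

Pass an instance via callbacks=[batch_history] in model.fit, then read, for example, batch_history.per_batch['accuracy'] after training. It only sees training batches; validation metrics such as val_loss and val_accuracy appear in the epoch-level logs passed to on_epoch_end.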
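One caveat when logging per batch: in TensorFlow 2.x tf.keras, the values in logs at batch end are running averages over the batches seen so far in the current epoch (the metrics are reset at the start of each epoch), not the value for that single batch. If you need the individual batch values, a small helper can undo the averaging. This is a sketch that assumes every batch in the epoch has the same size, since a smaller final batch is weighted differently and makes the last recovered value slightly off; per_batch_values is a hypothetical helper name.

```python
def per_batch_values(running_means):
    """Recover per-batch values from one epoch of running (cumulative) means.

    Assumes equal batch sizes, so the mean after batch k is simply the
    average of the first k per-batch values.
    """
    values = []
    previous_total = 0.0
    for k, mean in enumerate(running_means, start=1):
        total = k * mean  # sum of the first k per-batch values
        values.append(total - previous_total)
        previous_total = total
    return values
```

Apply it separately to each epoch's slice of the recorded list (for example, loss_history.accuracy split every steps-per-epoch entries), because the running averages restart at every epoch.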