
How can I find the number of epochs for which a Keras model was trained?



  1. I use callback_early_stopping() to stop the training early to avoid overfitting.

  2. I have been using callback_csv_logger() to log training performance. But sometimes I train hundreds of Keras models, and it does not make sense to log the entire training run just to know the number of epochs for which each model was trained.

library(keras)
library(kerasR)
library(tidyverse)


# Data
x = matrix(data = runif(30000), nrow = 10000, ncol = 3)
y = ifelse(rowSums(x) > 1.5 + runif(10000), 1, 0)
y = to_categorical(y)

# keras model
model <- keras_model_sequential() %>%   
  layer_dense(units = 50, activation = "relu", input_shape = ncol(x)) %>%
  layer_dense(units = ncol(y), activation = "softmax")

model %>%
  compile(loss = "categorical_crossentropy", 
          optimizer = optimizer_rmsprop(), 
          metrics = "accuracy")

model %>% 
  fit(x, y, 
      epochs = 1000,
      batch_size = 128,
      validation_split = 0.2, 
      callbacks = callback_early_stopping(monitor = "val_loss", patience = 5),
      verbose = 1)

Solution

  • To print the number of epochs (wherever you want), you can use a callback. Here's an example using the Python Keras API:

    import numpy as np
    from tensorflow.keras.callbacks import Callback

    class print_log_Callback(Callback):
        def __init__(self, logpath, steps):
            super().__init__()
            self.logpath = logpath          # text file the log is appended to
            self.steps = steps              # training batches per epoch
            self.losslst = np.zeros(steps)  # training loss of each batch in the epoch

        def on_train_batch_end(self, batch, logs=None):
            self.losslst[batch] = logs["loss"]
            with open(self.logpath, "a") as writefile:
                print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]), file=writefile)

        def on_test_batch_end(self, batch, logs=None):
            # During validation, logs["loss"] is the loss on the validation batch
            with open(self.logpath, "a") as writefile:
                print("For batch {}, val_loss is {:7.2f}.".format(batch, logs["loss"]), file=writefile)

        def on_epoch_end(self, epoch, logs=None):
            # epoch is 0-based, so epoch + 1 epochs have finished at this point
            with open(self.logpath, "a") as writefile:
                print("The val_loss for epoch {} is {:7.2f}.".format(epoch, logs["val_loss"]), file=writefile)
                print("The mean train loss is:", np.mean(self.losslst), file=writefile)
                print(file=writefile)  # blank line between epochs
            self.losslst = np.zeros(self.steps)  # reset for the next epoch
    

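    You pass it to fit() like this:

    model.fit(x, y, epochs=1000, validation_split=0.2,
              callbacks=[print_log_Callback(logpath=logpath, steps=int(steps))])

    where logpath is the path of the text file the log is written to, and steps is the number of training steps (batches) per epoch, which the callback uses to size its loss array.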

    This callback writes the entire training history of the network to a text file: the loss after every batch and at the end of every epoch.

    If you only need the epoch, you can keep just the on_epoch_end method and remove everything else.

    If you only want to write the losses once per epoch, you can use this modified version:

    import numpy as np
    from tensorflow.keras.callbacks import Callback

    class print_log_Callback(Callback):
        def __init__(self, logpath, steps):
            super().__init__()
            self.logpath = logpath
            self.steps = steps
            self.losslst = np.zeros(steps)

        def on_train_batch_end(self, batch, logs=None):
            # Only collect the batch losses; nothing is written here
            self.losslst[batch] = logs["loss"]

        def on_epoch_end(self, epoch, logs=None):
            with open(self.logpath, "a") as writefile:
                print("The val_loss for epoch {} is {:7.2f}.".format(epoch, logs["val_loss"]), file=writefile)
                print("The mean train loss is:", np.mean(self.losslst), file=writefile)
                print(file=writefile)
            self.losslst = np.zeros(self.steps)
    

    You can modify this callback to print a metric as well: for example, also print logs["accuracy"] in on_epoch_end.
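Following the suggestion to keep only on_epoch_end, here is a minimal sketch, assuming the Python Keras API from TensorFlow, of a callback that records just the number of completed epochs and the last epoch's logs (so a metric such as accuracy can be read or written as well). The class name EpochCounter and its attributes epochs_run and last_logs are hypothetical, not part of Keras; the fallback to object when TensorFlow is missing is only there so the logic can be exercised on its own.

```python
# Minimal sketch (hypothetical EpochCounter class): count the epochs that
# fit() actually ran, e.g. when EarlyStopping ends training early.
try:
    from tensorflow import keras
    Callback = keras.callbacks.Callback
except ImportError:  # lets the logic below run without TensorFlow installed
    Callback = object


class EpochCounter(Callback):
    """Records how many epochs finished and the logs of the last one."""

    def __init__(self, logpath=None):
        super().__init__()
        self.logpath = logpath  # optional text file; None means "don't write"
        self.epochs_run = 0     # number of completed epochs
        self.last_logs = {}     # metrics of the most recent epoch

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # epoch is 0-based, so epoch + 1 epochs have completed so far
        self.epochs_run = epoch + 1
        self.last_logs = dict(logs)
        if self.logpath is not None:
            with open(self.logpath, "a") as f:
                print("epoch {}: val_loss={}, accuracy={}".format(
                    self.epochs_run, logs.get("val_loss"), logs.get("accuracy")),
                    file=f)
```

Usage, assuming a compiled Python model and data x, y: create `counter = EpochCounter()`, call `history = model.fit(x, y, epochs=1000, batch_size=128, validation_split=0.2, callbacks=[keras.callbacks.EarlyStopping(monitor="val_loss", patience=5), counter])`, and read `counter.epochs_run` afterwards. The History object returned by `fit()` carries the same count without any custom callback: `len(history.history["loss"])` is the number of epochs that ran. In the R interface, assigning the result (`history <- model %>% fit(...)`) should give a comparable object whose `history$metrics$loss` has one entry per epoch.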