Search code examples
tensorflow keras neural-network epoch

Save history of model.fit for different epochs


I trained my model with epochs=10, then retrained it with epochs=3, and then again with epochs=5. I want to combine the history of all three runs. For example, let h1 = history of model.fit for epochs=10, h2 = history of model.fit for epochs=3, and h3 = history of model.fit for epochs=5.

Now, in a single variable h, I want h1 + h2 + h3: all the histories appended to one variable so that I can plot some graphs.

The code is:

# Three consecutive training runs on the same model (10, 3 and 5 epochs).
# model.fit returns a History object; it is kept under its own name for each
# run (h1, h2, h3) instead of being discarded, so the metrics of every run
# remain available afterwards.

# Run 1: 10 epochs.
start_time = time.time()

h1 = model.fit(x=X_train, y=y_train, batch_size=32, epochs=10, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")


# Run 2: 3 more epochs, continuing from the weights left by run 1.
start_time = time.time()

h2 = model.fit(x=X_train, y=y_train, batch_size=32, epochs=3, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")

# Run 3: 5 more epochs.
start_time = time.time()

h3 = model.fit(x=X_train, y=y_train, batch_size=32, epochs=5, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")


Solution

  • You can achieve this by creating a class that subclasses tf.keras.callbacks.Callback and passing an instance of that class as a callback to model.fit.

    import csv
    import tensorflow.keras.backend as K
    from tensorflow import keras
    import os
    
    # Directory to save model history after every epoch. The trailing
    # separator makes `model_directory + 'model_history.csv'` resolve to a
    # file inside the directory rather than to './xyzmodel_history.csv'.
    model_directory = './xyz/'


    class StoreModelHistory(keras.callbacks.Callback):
        """Append the logs of every finished epoch to a CSV file.

        The file is ``model_history.csv`` inside ``model_directory``. It is
        opened in append mode, so rows written during successive
        ``model.fit`` calls accumulate in the same file.
        """

        def on_epoch_end(self, batch, logs=None):
            """Write one CSV row containing the logs of the finished epoch.

            Args:
                batch: First positional argument passed to this hook (the
                    Keras documentation names it ``epoch``); not used here.
                logs: Dict of metric names to values for the epoch. When it
                    has no ``'lr'`` entry, the optimizer's current learning
                    rate is added to it in place.
            """
            if logs is None:
                logs = {}

            if 'lr' not in logs:
                logs['lr'] = K.get_value(self.model.optimizer.lr)

            # Create the target directory on first use instead of failing.
            os.makedirs(model_directory, exist_ok=True)
            history_path = os.path.join(model_directory, 'model_history.csv')

            # The header is needed only once: when the file does not exist
            # yet or is still empty. Check the exact path that is written.
            needs_header = (not os.path.isfile(history_path)
                            or os.path.getsize(history_path) == 0)

            # newline='' is what the csv module expects for files it writes.
            with open(history_path, 'a', newline='') as f:
                writer = csv.DictWriter(f, fieldnames=list(logs))
                if needs_header:
                    writer.writeheader()
                writer.writerow(logs)


    # Pass the callback to every model.fit call whose epochs should be logged;
    # each call appends its rows to the same CSV file.
    model.fit(...,callbacks=[StoreModelHistory()])
    

    Then you can load the CSV file and plot the model's loss, learning rate, metrics, etc.

    import pandas as pd
    import matplotlib.pyplot as plt
    
    # Load the per-epoch history that was logged to CSV during training.
    history_dataframe = pd.read_csv(model_directory+'model_history.csv',sep=',')

    # Number of epochs the model has trained for. It is taken from the number
    # of logged rows rather than hard-coded, so the x axis always matches the
    # data even when the file holds several training runs (e.g. 10 + 3 + 5).
    EPOCH = len(history_dataframe)
    epoch_axis = range(1, EPOCH + 1)


    # Plot training & validation loss values
    plt.style.use("ggplot")
    plt.plot(epoch_axis,
             history_dataframe['loss'])
    plt.plot(epoch_axis,
             history_dataframe['val_loss'],
             linestyle='--')
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.show()