Tags: tensorflow, keras, time-series, lstm, recurrent-neural-network

How to update/re-train an LSTM model when new data arrives?


In a real-time application of Long Short-Term Memory (LSTM) networks or Recurrent Neural Networks (RNNs), the data arrives (is fed in) in real time. I don't know how to introduce the newly arrived data to the pre-trained network in those cases.

Imagine I trained an LSTM on 1-D input data like [0,1,2,3,4,5,6,7] with a time window (sequence length) of 5. The input dataset then looks like this:

X            Y
0-1-2-3-4    5
1-2-3-4-5    6
2-3-4-5-6    7
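
A minimal sketch of how such a windowed dataset could be built with NumPy. The helper name `make_windows` is hypothetical, and the trailing feature axis assumes the usual Keras recurrent-layer input layout of (samples, timesteps, features):

```python
import numpy as np


def make_windows(series, window):
    """Split a 1-D series into (X, y): `window` past values -> the next value."""
    series = np.asarray(series, dtype="float32")
    n_samples = len(series) - window
    if n_samples <= 0:
        raise ValueError("series must be longer than the window")
    # Row i holds series[i : i + window]; its target is series[i + window].
    X = np.stack([series[i:i + window] for i in range(n_samples)])
    y = series[window:]
    # Add a feature axis: (samples, timesteps) -> (samples, timesteps, 1).
    return X[..., np.newaxis], y


X, y = make_windows([0, 1, 2, 3, 4, 5, 6, 7], window=5)
print(X.shape)  # (3, 5, 1)
print(y)        # [5. 6. 7.]
```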

I predicted the value for the 8th timestep and saved the trained network. Now I have received the actual data for the 8th timestep, and I want the network to learn from it before predicting step 9. So I would create a new sequence:

X_train    y_train
3-4-5-6-7  8

Now the question is how I should re-train my loaded network using this sequence, given the following template for fitting the new data:

model.fit(X_train, y_train, validation_data=(X_test, y_test), 
      epochs=epochs, batch_size=batch_size, verbose=1, shuffle=False)
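
To make the mechanics concrete, here is a minimal sketch of how that single new window could be shaped and passed to `fit`. The helper `update_on_new_window` is hypothetical, `model` is assumed to be an already compiled Keras model loaded from disk, and the (1, 5, 1) input shape assumes a (samples, timesteps, features) layout. No `validation_data` is passed, since one sample leaves nothing to hold out:

```python
import numpy as np


def update_on_new_window(model, history, new_value, window=5, epochs=1):
    """Append the newly observed value and fit on the newest window only."""
    history = list(history) + [new_value]
    # The last `window` values before the new one are the input; the new value is the target.
    x_new = np.asarray(history[-window - 1:-1], dtype="float32").reshape(1, window, 1)
    y_new = np.asarray([new_value], dtype="float32")
    model.fit(x_new, y_new, epochs=epochs, batch_size=1, verbose=0, shuffle=False)
    return history


# Usage (assumption: a model was saved earlier; the path is hypothetical):
# model = keras.models.load_model("lstm.keras")
# history = update_on_new_window(model, [0, 1, 2, 3, 4, 5, 6, 7], 8)
```

With the series above, this fits on the input 3-4-5-6-7 with target 8.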

I would appreciate it if you could explain these detailed questions as well:

  • Apparently, there won't be any validation/test set since there is only one sequence to be trained. Is that fine?
  • If so, how can we prevent the model from overfitting to the latest data?
  • Would it be a good idea to grab a couple of previous sequences and create a new train/test set out of them including the new sequence?
  • What would be a good number of epochs to avoid under/over-fit (without having a validation set)?

Solution

  • In general, you don't want to re-train the network every time a data point comes in, for the reason you already hinted at: you have no validation set, so there is no real way to verify that the updated network outperforms your previous model, and it could fail quite badly. The computational cost of real-time training could also grow large. Instead, you will generally want to collect a "reasonably" sized dataset of new data before re-training (what counts as reasonable will vary, but I'd suggest at least 10-15% of the size of your original training dataset).

    You will then want to re-train (described in more detail below) and test the model on segments of both the original validation/test dataset and the newly collected one. Compare the performance of the new model to that of previously trained models before deploying anything new.
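
The collect-then-compare workflow above could be sketched roughly as follows. All names here are hypothetical, the 10% default is just the lower end of the range suggested above, and the two `*_predict` arguments are assumed to be callables returning a flat sequence of predictions (for a Keras model, something like `lambda X: model.predict(X).ravel()`):

```python
def mse(y_true, y_pred):
    """Mean squared error over two equal-length sequences of numbers."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def enough_new_data(n_new, n_original, min_fraction=0.10):
    """True once the new data reaches the chosen fraction of the original set."""
    return n_new >= min_fraction * n_original


def candidate_is_no_worse(current_predict, candidate_predict, eval_sets, tolerance=0.0):
    """Check the re-trained model against the current one on every evaluation set.

    `eval_sets` maps a label (e.g. "old", "new") to an (X, y) pair.
    """
    for X, y in eval_sets.values():
        current_err = mse(y, current_predict(X))
        candidate_err = mse(y, candidate_predict(X))
        if candidate_err > current_err + tolerance:
            return False  # candidate regressed on this set: keep the current model
    return True
```

Requiring the candidate to hold up on the old validation data as well as the new data is one simple way to notice forgetting before deployment.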

    If so, how can we prevent the model from overfitting to the latest data?

    This gets into a topic called catastrophic forgetting and lifelong learning. The topic is too dense to describe in any detail here, but the basic problem is that when you train only on new data, the NN loses its prior ability to generalize. For a time series model that can be a good thing in certain circumstances (your distribution fundamentally and permanently shifts), but it is usually a bad thing, as your model may lose knowledge of complex seasonal patterns that will come back. For this reason it is usually recommended that you append your new training data to the old dataset and train a model on the full dataset (though this is obviously computationally expensive). Alternatively, you could look into techniques like Elastic Weight Consolidation, but those are currently tricky to apply to deep time series models.
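
As a rough illustration of the "append and re-train on everything" option, the sketch below merges old and new windows and holds out the most recent ones for validation. The helper `merge_and_split` and the 20% hold-out are assumptions for illustration, and both inputs are assumed to be in chronological order. Note that adjacent windows overlap, so a chronological split like this is only an approximate guard against leakage:

```python
import numpy as np


def merge_and_split(X_old, y_old, X_new, y_new, val_fraction=0.2):
    """Append new windows to the old ones and hold out the most recent for validation."""
    X = np.concatenate([X_old, X_new])
    y = np.concatenate([y_old, y_new])
    n_val = max(1, int(round(len(X) * val_fraction)))
    # No shuffling: the last n_val windows (the newest) form the validation set.
    return (X[:-n_val], y[:-n_val]), (X[-n_val:], y[-n_val:])


# Toy data with the (samples, timesteps, features) layout.
X_old, y_old = np.zeros((8, 5, 1), dtype="float32"), np.zeros(8, dtype="float32")
X_new, y_new = np.ones((2, 5, 1), dtype="float32"), np.ones(2, dtype="float32")
(X_train, y_train), (X_val, y_val) = merge_and_split(X_old, y_old, X_new, y_new)
print(X_train.shape, X_val.shape)  # (8, 5, 1) (2, 5, 1)
```

The resulting pairs could then be passed to `model.fit(X_train, y_train, validation_data=(X_val, y_val), ...)`.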

    In any case, you need to think carefully about your model deployment life-cycle and your underlying data/model infrastructure (how you store data, how you track your models, and how much compute power you have). Then maybe read more about catastrophic forgetting and how to overcome it. This paper might help; however, keep in mind that for many scenarios re-training from scratch might be the best option.