I am new to Keras and still looking for a way to train a model continuously. My dataset is too large to store in memory, so I am supposed to keep it in a DB (a NoSQL DB such as MongoDB or HBase) and train on the records batch-wise. My model is an LSTM with multiple inputs and outputs. My current training and prediction look like this:
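from keras.models import Sequential
from keras.layers import LSTM

# in_dim, xtrain, ytrain and xtest are all held in memory today
model = Sequential()
model.add(LSTM(64, input_shape=in_dim, activation="relu"))
model.compile(loss="mse", optimizer="adam")
model.fit(xtrain, ytrain, epochs=100, batch_size=12, verbose=0)
ypred = model.predict(xtest)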
However, I am still looking for a clear, simple example that shows how to feed batches of records pulled from the DB into model training.
If your dataset is too large to fit in memory, write a generator that produces one batch of data at a time, then train on its output with fit_generator. If you write the generator so that it can be pickled, you can also enable the use_multiprocessing option of fit_generator to run the generator in several processes and keep multiple batches ready, which can significantly reduce the time spent waiting on disk I/O. (In TensorFlow 2.1 and later, fit_generator is deprecated and model.fit accepts generators and keras.utils.Sequence objects directly.)
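A plain Python generator function is the simplest form of this idea. The sketch below is a minimal illustration, assuming a hypothetical fetch_range(start, stop) query that returns records start to stop - 1 (with MongoDB this might be a skip/limit query, with HBase a scan). Here it only returns random data. Because fit_generator keeps pulling batches across epochs, the generator loops forever and the epoch length is set separately with steps_per_epoch.

```python
import numpy as np


def fetch_range(start, stop, n_features=50, n_classes=5):
    """Hypothetical stand-in for a DB query returning records start..stop-1."""
    rng = np.random.default_rng(start)  # seeded per page so the demo is repeatable
    X = rng.standard_normal((stop - start, n_features))
    y = rng.integers(0, n_classes, stop - start)
    return X, y


def batch_generator(total, batch_size):
    """Yield (X, y) batches forever, one DB page at a time."""
    while True:  # restart from the first record after the last page
        for start in range(0, total, batch_size):
            stop = min(start + batch_size, total)  # last page may be smaller
            yield fetch_range(start, stop)
```

It could be used as model.fit_generator(batch_generator(total, 32), steps_per_epoch=math.ceil(total / 32), epochs=10). A plain generator like this is safest with a single worker: with use_multiprocessing=True and several workers, Keras warns that a generator may duplicate data, whereas a keras.utils.Sequence is indexed by batch number, so each batch is fetched only once per epoch.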
import keras
import numpy as np

# Dummy database class standing in for the real DB
class DB:
    def get_total_records_count(self):
        return 1000000

    def read_records_at(self, ids):
        # Return the features and labels of the records with these ids
        X = np.random.randn(len(ids), 50)
        y = np.random.randint(0, 5, len(ids))
        return X, y

# Generator which generates one batch at a time
class DataGenerator(keras.utils.Sequence):
    def __init__(self, db, batch_size=32):
        super().__init__()
        self.db = db
        self.n = self.db.get_total_records_count()
        self.idx = np.arange(self.n)
        self.batch_size = batch_size

    # Number of batches per epoch
    def __len__(self):
        return int(np.floor(self.n / self.batch_size))

    # Generate the batch (X, y) at position index
    def __getitem__(self, index):
        idxs = self.idx[index*self.batch_size:(index+1)*self.batch_size]
        return self.db.read_records_at(idxs)
model = keras.models.Sequential()
model.add(keras.layers.Dense(5, input_dim=50, activation='softmax'))
model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy')

df = DataGenerator(DB(), 4)
model.fit_generator(generator=df, epochs=1)
Epoch 1/1
250000/250000 [==============================] - 380s 2ms/step - loss: 7.1443
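The DataGenerator above visits records in the same order every epoch, and because __len__ uses floor it drops a final partial batch whenever the record count is not a multiple of batch_size. Below is a minimal sketch of a variant that reshuffles the record ids in on_epoch_end (a Sequence hook Keras calls after each epoch) and keeps the last, smaller batch. The class name ShuffledDBSequence and its read_records_at/n_records parameters are illustrative, not part of Keras; the fallback to object exists only so the sketch runs where Keras is not installed.

```python
import math

import numpy as np

try:  # subclass Keras' Sequence when Keras is available
    import keras
    _SequenceBase = keras.utils.Sequence
except ImportError:  # fallback so the sketch also runs without Keras
    _SequenceBase = object


class ShuffledDBSequence(_SequenceBase):
    """Batches of DB records, visited in a new random order every epoch."""

    def __init__(self, read_records_at, n_records, batch_size=32, seed=0):
        super().__init__()
        self.read_records_at = read_records_at  # callable: ids -> (X, y)
        self.n = int(n_records)
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)
        self.idx = np.arange(self.n)
        self.rng.shuffle(self.idx)  # random order for the first epoch

    def __len__(self):
        # ceil keeps the final, smaller batch instead of dropping it
        return math.ceil(self.n / self.batch_size)

    def __getitem__(self, index):
        ids = self.idx[index * self.batch_size:(index + 1) * self.batch_size]
        return self.read_records_at(ids)

    def on_epoch_end(self):
        # Keras calls this after each epoch; draw a new order for the next one
        self.rng.shuffle(self.idx)
```

With the dummy DB it could be built as ShuffledDBSequence(DB().read_records_at, DB().get_total_records_count(), batch_size=4) and passed to model.fit_generator(seq, epochs=10, workers=4, use_multiprocessing=True, max_queue_size=10). For the multi-input, multi-output LSTM in the question, __getitem__ would return ([x1, x2], [y1, y2]), with each input shaped (batch, timesteps, features). Note that shuffled ids turn every batch into random-access lookups, which may be slower on MongoDB or HBase than reading contiguous ranges; whether that trade-off pays off depends on the store.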