
Is it logical to loop on model.fit in Keras?


Is it reasonable to call model.fit in a loop, as below, so that the full dataset never has to fit in memory at once?

for path in ['xaa', 'xab', 'xac', 'xad']:
    x_train, y_train = prepare_data(path)
    model.fit(x_train, y_train, batch_size=50, epochs=20, shuffle=True)

model.save('model')

Solution

  • It is, but prefer model.train_on_batch if each iteration produces exactly one batch, since that avoids some of the per-call overhead of fit.
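    As a rough sketch of what driving the loop yourself looks like: the file names and prepare_data come from the question, while the outer epoch loop and the stub model below are assumptions added so the example runs standalone (in real code, model would be your compiled Keras model):

    ```python
    # Hypothetical stand-ins so the sketch runs without Keras;
    # in practice, prepare_data and model come from your own code.
    def prepare_data(path):
        # Pretend each file holds 5 samples.
        data = [float(i) for i in range(5)]
        return data, data

    class StubModel:
        """Stands in for a compiled Keras model."""
        def __init__(self):
            self.updates = 0

        def train_on_batch(self, x, y):
            # Real Keras performs one gradient update per call.
            self.updates += 1

    model = StubModel()
    # With train_on_batch, epochs live in your own outer loop
    # instead of the epochs= argument of fit.
    for epoch in range(20):
        for path in ['xaa', 'xab', 'xac', 'xad']:
            x_train, y_train = prepare_data(path)
            model.train_on_batch(x_train, y_train)

    print(model.updates)  # 20 epochs * 4 files = 80 updates
    ```
    
    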

    You can also create a generator and use model.fit_generator() (note that in recent versions of Keras, fit_generator is deprecated and model.fit accepts generators directly):

    def dataGenerator(paths, batch_size):
        while True:  # generators for Keras must loop forever
            for path in paths:
                x_train, y_train = prepare_data(path)
    
                # Round up so the final partial batch is not dropped
                total_samples = x_train.shape[0]
                batches = total_samples // batch_size
                if total_samples % batch_size > 0:
                    batches += 1
    
                for batch in range(batches):
                    section = slice(batch * batch_size, (batch + 1) * batch_size)
                    yield x_train[section], y_train[section]
    

    Create and use:

    gen = dataGenerator(['xaa', 'xab', 'xac', 'xad'], 50)
    model.fit_generator(gen,
                        steps_per_epoch=expectedTotalNumberOfYieldsForOneEpoch,
                        epochs=epochs)
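
    For steps_per_epoch, count one step per yield: sum the per-file batch counts, rounding up because the generator above emits a final partial batch for each file. A small sketch, with hypothetical sample counts standing in for your real files:

    ```python
    import math

    batch_size = 50
    # Hypothetical sample counts per file; in practice,
    # read them from the data itself.
    samples_per_file = {'xaa': 1000, 'xab': 1000, 'xac': 1000, 'xad': 730}

    # One step per batch, rounded up per file to match the generator.
    steps_per_epoch = sum(math.ceil(n / batch_size)
                          for n in samples_per_file.values())
    print(steps_per_epoch)  # 20 + 20 + 20 + 15 = 75
    ```
    
    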