Tags: python, keras, data-augmentation

How to fit Keras ImageDataGenerator for large data sets using batches


I want to use the Keras ImageDataGenerator for data augmentation. To do so, I have to call the .fit() function on the instantiated ImageDataGenerator object, passing it my training data as a parameter, as shown below.

image_datagen = ImageDataGenerator(featurewise_center=True, rotation_range=90)
image_datagen.fit(X_train, augment=True)
train_generator = image_datagen.flow_from_directory('data/images')
model.fit_generator(train_generator, steps_per_epoch=2000, epochs=50)

However, my training data set is too large to fit into memory when loaded all at once. Consequently, I would like to fit the generator in several steps, using subsets of my training data.

Is there a way to do this?

One potential solution that came to my mind is to load batches of my training data using a custom generator function and to fit the image generator multiple times in a loop. However, I am not sure whether the fit() function of ImageDataGenerator can be used in this way, as it might reset on each call.

As an example of how it might work:

def custom_train_generator():
    # Code loading training data subsets X_batch
    yield X_batch


image_datagen = ImageDataGenerator(featurewise_center=True, rotation_range=90)
gen = custom_train_generator()

for batch in gen:
    image_datagen.fit(batch, augment=True)

train_generator = image_datagen.flow_from_directory('data/images')
model.fit_generator(train_generator, steps_per_epoch=2000, epochs=50)

Solution

  • NEWER TF VERSIONS (>=2.5):

    ImageDataGenerator() has been deprecated in favour of:

    tf.keras.utils.image_dataset_from_directory

    Its full signature, from the documentation:

    tf.keras.utils.image_dataset_from_directory(
        directory,
        labels='inferred',
        label_mode='int',
        class_names=None,
        color_mode='rgb',
        batch_size=32,
        image_size=(256, 256),
        shuffle=True,
        seed=None,
        validation_split=None,
        subset=None,
        interpolation='bilinear',
        follow_links=False,
        crop_to_aspect_ratio=False,
        **kwargs
    )
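
    As a rough sketch of how the question's workflow could look with this API (assuming a recent TF where tf.keras.utils.image_dataset_from_directory and the tf.keras.layers.Random* preprocessing layers are available, a directory layout like data/images/<class>/*.jpg, and an already compiled model): augmentation is expressed as preprocessing layers, and batches are streamed from disk, so the full data set never has to sit in memory.

    import tensorflow as tf

    # batches are read from disk on the fly instead of being loaded all at once
    train_ds = tf.keras.utils.image_dataset_from_directory(
        'data/images',                # hypothetical path, adapt to your layout
        image_size=(256, 256),
        batch_size=32)

    # augmentation is now done with preprocessing layers instead of ImageDataGenerator
    augment = tf.keras.Sequential([
        tf.keras.layers.RandomFlip('horizontal'),
        tf.keras.layers.RandomRotation(0.25),  # random rotation of up to ~90 degrees
    ])

    train_ds = train_ds.map(
        lambda x, y: (augment(x, training=True), y),
        num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

    model.fit(train_ds, epochs=50)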
    

  • OLDER TF VERSIONS (<2.5):

    ImageDataGenerator() already lets you load the data in batches: you set the batch_size parameter when you call datagen.flow() or datagen.flow_from_directory(), and the resulting iterator is what you pass to fit_generator(). There is no need (other than as good practice, if you want) to write a generator from scratch.

    IMPORTANT NOTE:

    Starting from TensorFlow 2.1, .fit_generator() has been deprecated and you should use .fit() instead.

    Example taken from Keras official documentation:

    datagen = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True)
    
    # compute quantities required for featurewise normalization
    # (std, mean, and principal components if ZCA whitening is applied)
    datagen.fit(x_train)
    
    # TF <= 2.0
    # fits the model on batches with real-time data augmentation:
    model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                        steps_per_epoch=len(x_train) // 32, epochs=epochs)
    
    # TF >= 2.1
    model.fit(datagen.flow(x_train, y_train, batch_size=32),
              steps_per_epoch=len(x_train) // 32, epochs=epochs)
    

    I would suggest reading this excellent article about ImageDataGenerator and augmentation: https://machinelearningmastery.com/how-to-configure-image-data-augmentation-when-training-deep-learning-neural-networks/

    The solution to your problem lies in this line of code (using either flow() or flow_from_directory()):

    # prepare iterator
    it = datagen.flow(samples, batch_size=1)
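
    To make that concrete for your case, here is a minimal sketch using flow_from_directory (assuming the datagen configured above, a directory layout like data/images/<class>/*.jpg, and an already compiled model); every batch is read from disk on demand, so the full training set never has to fit into memory:

    # hypothetical layout: data/images/<class_name>/*.jpg
    train_it = datagen.flow_from_directory(
        'data/images',
        target_size=(256, 256),
        batch_size=32,
        class_mode='categorical')

    # TF >= 2.1: model.fit() accepts the iterator directly
    model.fit(train_it,
              steps_per_epoch=train_it.samples // 32,
              epochs=50)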
    

    For creating your own DataGenerator, this link is a good starting point: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
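
    As an illustration of the idea behind that tutorial, a minimal, hypothetical sketch based on tensorflow.keras.utils.Sequence could look like the following (image_paths, labels and load_image are placeholders you would replace with your own loading code):

    import numpy as np
    from tensorflow.keras.utils import Sequence

    class ImageSequence(Sequence):
        """Loads one batch of images at a time instead of the whole data set."""

        def __init__(self, image_paths, labels, batch_size=32):
            self.image_paths = image_paths
            self.labels = labels
            self.batch_size = batch_size

        def __len__(self):
            # number of batches per epoch
            return int(np.ceil(len(self.image_paths) / self.batch_size))

        def __getitem__(self, idx):
            # only this batch is held in memory
            batch_paths = self.image_paths[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_labels = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
            X = np.stack([load_image(p) for p in batch_paths])  # load_image is assumed
            return X, np.array(batch_labels)

    # usage (hypothetical): model.fit(ImageSequence(paths, labels), epochs=50)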

    IMPORTANT NOTE (2):

    If you use Keras from TensorFlow (Keras inside TensorFlow), then in both the code presented here and the tutorials you consult, make sure you replace the imports (and network-creation snippets) of the form:

    from keras.x.y.z import A
    

    WITH

    from tensorflow.keras.x.y.z import A
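
    For example, for the generator used throughout this answer:

    # standalone Keras (older tutorials):
    # from keras.preprocessing.image import ImageDataGenerator

    # Keras inside TensorFlow:
    from tensorflow.keras.preprocessing.image import ImageDataGenerator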