Tags: keras, tensorflow2.0, tensorflow2.x

Convert model.fit_generator to model.fit


I have the following code:

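from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training images: rescaled to [0, 1] and randomly augmented
train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
# Validation images: rescaled only, no augmentation
test_datagen = ImageDataGenerator(rescale=1./255)
# Each iterator yields (images, labels) batches read from one subfolder per class
train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')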

I then train the model with model.fit_generator as follows:

model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)

Since model.fit_generator is now deprecated, what is the proper way to change it to model.fit in this case?


Solution

  • You just have to change model.fit_generator() to model.fit().

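    As of TensorFlow 2.1, model.fit() also accepts generators as input, including the iterators returned by flow_from_directory(), and the steps_per_epoch, epochs, validation_data and validation_steps arguments keep the same meaning. As simple as that.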

    From TensorFlow's official documentation for Model.fit_generator:

    Warning: THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Please use Model.fit, which supports generators.
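Here is a minimal runnable sketch of the same call pattern, assuming TensorFlow 2.1 or later. Because the data/train and data/validation folders are not available here, the hypothetical batch_generator function stands in for train_generator and validation_generator by yielding random (images, labels) batches of the same shape; the tiny model is likewise only an illustration, not the asker's model.

```python
import numpy as np
import tensorflow as tf


def batch_generator(batch_size=8, size=(150, 150), seed=0):
    """Hypothetical stand-in for flow_from_directory(): yields
    (images, binary labels) batches forever, already scaled to [0, 1)."""
    rng = np.random.default_rng(seed)
    while True:
        x = rng.random((batch_size, *size, 3), dtype=np.float32)
        y = rng.integers(0, 2, size=(batch_size,)).astype(np.float32)
        yield x, y


# Illustrative binary classifier for 150x150 RGB inputs (not the asker's model).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(150, 150, 3)),
    tf.keras.layers.Conv2D(4, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Same keyword arguments as the fit_generator() call, passed to fit() instead.
# With the question's own iterators this would read:
#   model.fit(train_generator, steps_per_epoch=2000, epochs=50,
#             validation_data=validation_generator, validation_steps=800)
# Small step counts are used here only so the sketch finishes quickly.
history = model.fit(
    batch_generator(seed=0),
    steps_per_epoch=5,
    epochs=2,
    validation_data=batch_generator(seed=1),
    validation_steps=2,
    verbose=0)
```

Because batch_generator never terminates, steps_per_epoch and validation_steps are required here. The DirectoryIterator objects returned by flow_from_directory() report their own length, so with them these two arguments can usually be omitted, in which case Keras runs one full pass over the images per epoch.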