Tags: python, tensorflow, keras, generator

Tensorflow dataset from generator OutOfRangeError: End of sequence


Function that returns a generator:

import glob

import numpy as np
from PIL import Image

def img_x_gen(dir):
    files = glob.glob(f'{dir}/*.jpg')
    for file in files:
        X_i = np.asarray(Image.open(file))
        X_i = X_i / 255.0
        yield X_i, X_i  # It's an autoencoder so return X_i twice

Creating the datasets:

types = (tf.float32, tf.float32)
shapes = (img_shape, img_shape)
ds_train = tf.data.Dataset.from_generator(img_x_gen, types, shapes,
    args=['train/img_sq']).batch(batch_size)
ds_valid = tf.data.Dataset.from_generator(img_x_gen, types, shapes,
    args=['valid/img_sq']).batch(batch_size)
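One thing worth checking in the dataset code above: the tf.data.Dataset.from_generator documentation says the args are evaluated and passed to the generator as NumPy values, so the string 'train/img_sq' reaches img_x_gen as bytes, not str. The plain-Python sketch below needs no TensorFlow; the byte string is an assumption standing in for what the generator receives. It shows what the f-string in img_x_gen then builds, and that glob finds nothing for that pattern, so the generator would yield zero elements.

```python
import glob

# Assumption: this is the value img_x_gen receives via from_generator(args=[...])
received = b'train/img_sq'

# Formatting bytes in an f-string embeds their repr, not the decoded text
bad_pattern = f'{received}/*.jpg'
print(bad_pattern)   # b'train/img_sq'/*.jpg

# Decoding first gives the intended pattern
good_pattern = f'{received.decode()}/*.jpg'
print(good_pattern)  # train/img_sq/*.jpg

# The bad pattern points at a directory literally named "b'train", which
# normally does not exist, so a loop over these "files" yields nothing
print(glob.glob(bad_pattern))
```

If the decoded pattern is used instead, the generator sees the real .jpg files.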

Calling fit method:

vae.fit(ds_train, epochs=3, validation_data=ds_valid, verbose=True) 

I get the error in the question title:

OutOfRangeError:  End of sequence

The training set has 894 examples and the validation set has 247; batch_size is 32. I know the model works if I load the data into memory.

I've also tried making a generator and manually batching (and passing steps_per_epoch and validation_steps to the model.fit method), but that runs into a similar error: Your input ran out of data; interrupting training.
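The numbers above make the second error easy to reproduce on paper. With 894 training images and a batch size of 32 there are 27 full batches plus one partial batch of 30, so a one-pass generator can supply at most 28 batches in total. When Keras is given a plain generator with steps_per_epoch, it keeps pulling from that same generator object across epochs, and TensorFlow's full warning asks for input that can produce at least steps_per_epoch * epochs batches. The sketch below simulates this in plain Python, assuming steps_per_epoch was set to 28; make_batches and run_epochs are hypothetical helpers, not Keras functions, and they yield batch sizes instead of image arrays.

```python
import math

def make_batches(n_examples, batch_size):
    """One pass over the data, like the finite img_x_gen plus manual batching.
    Yields the size of each batch instead of real image arrays."""
    for start in range(0, n_examples, batch_size):
        yield min(batch_size, n_examples - start)

def run_epochs(gen, epochs, steps_per_epoch):
    """Pull steps_per_epoch batches per epoch from ONE generator object.
    Returns how many batches were delivered before it ran dry."""
    delivered = 0
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            try:
                next(gen)
            except StopIteration:
                return delivered  # the point where Keras reports "ran out of data"
            delivered += 1
    return delivered

n_train, batch_size, epochs = 894, 32, 3
steps = math.ceil(n_train / batch_size)  # 27 full batches + one batch of 30
needed = steps * epochs                  # total batches requested over 3 epochs
got = run_epochs(make_batches(n_train, batch_size), epochs, steps)
print(steps, needed, got)                # 28 84 28
```

The first epoch consumes all 28 batches and the second epoch immediately hits StopIteration, which matches the "ran out of data" symptom.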

So clearly I don't understand something about generators.


Solution

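  • You don't need model.fit_generator() for this: it has been deprecated since TensorFlow 2.1, and model.fit() accepts generators and tf.data datasets directly. You are receiving this error because your generator yields fewer elements than training requests: a one-pass generator is exhausted after the first epoch, and if it finds no files at all it yields nothing. You can quickly fix the first problem by making it an infinite generator. For the second, note that from_generator passes args to the generator as NumPy values, so the directory arrives as bytes and must be decoded before it is used in the glob pattern; otherwise the pattern matches no files.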

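    import glob

    import numpy as np
    from PIL import Image

    def img_x_gen(img_dir):
        # from_generator passes args as NumPy values, so a str path arrives as bytes
        if isinstance(img_dir, bytes):
            img_dir = img_dir.decode()
        files = glob.glob(f'{img_dir}/*.jpg')
        if not files:
            # Fail loudly instead of looping forever over an empty list
            raise FileNotFoundError(f'no .jpg files found in {img_dir}')
        while True:  # make the generator infinite; fit() stops via steps_per_epoch
            for file in files:
                X_i = np.asarray(Image.open(file), dtype=np.float32) / 255.0
                yield X_i, X_i  # autoencoder: input and target are the same image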

    Then pass steps_per_epoch and validation_steps to fit, because the datasets no longer end on their own:

    vae.fit(ds_train, epochs=3, validation_data=ds_valid, verbose=True,
            steps_per_epoch=<TRAIN LENGTH>//batch_size,
            validation_steps=<VAL LENGTH>//batch_size)
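To see why the infinite generator avoids the "ran out of data" problem, here is a plain-Python simulation without TensorFlow. infinite_examples, batched, and the file names are hypothetical stand-ins for the fixed img_x_gen, Dataset.batch, and the real images. The simulation requests steps_per_epoch * epochs batches from one stream. With 894 // 32 = 27 steps per epoch, the 30 images left over at the end of the file list are not lost; they fill the first batch of the next epoch together with the first two images from the next pass.

```python
import itertools

def infinite_examples(files):
    """Stand-in for the fixed img_x_gen: cycles over the file list forever.
    Hypothetical: yields file names instead of decoded images."""
    if not files:
        raise ValueError('no files to cycle over')
    while True:
        for f in files:
            yield f

def batched(gen, batch_size):
    """Group consecutive items of an endless stream into lists, like .batch()."""
    while True:
        yield list(itertools.islice(gen, batch_size))

files = [f'img_{i}.jpg' for i in range(894)]  # hypothetical, same size as the training set
batch_size, epochs = 32, 3
steps_per_epoch = len(files) // batch_size    # mirrors <TRAIN LENGTH>//batch_size
stream = batched(infinite_examples(files), batch_size)
seen = [next(stream) for _ in range(steps_per_epoch * epochs)]

print(steps_per_epoch, len(seen))  # 27 81
print(seen[27][0], seen[27][-1])   # img_864.jpg img_1.jpg
```

Inside tf.data, an alternative to the while True loop is to keep the generator finite and call .repeat() on the dataset; either way the input has no end, so steps_per_epoch and validation_steps must be passed to fit.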