Tags: python, tensorflow, keras, generator

Tensorflow Keras model.fit using generator function and steps_per_epoch


I have this model.fit call:

transformer.fit(
   x=data_generation.generate_dataset(batch_size, dontchange, train_indices, filenames),
   epochs=epochs,
   steps_per_epoch=len(train_indices),
   validation_data=data_generation.generate_dataset(batch_size, dontchange, val_indices, filenames),
   validation_steps=len(val_indices)
)

And then this is part of my generate_dataset function:

def generate_dataset(batch_size_in, dontChange, index_list_in, filenames):
    epoch = 0
    while True:
        epoch = epoch + 1
        raw_dataset = tf.data.TFRecordDataset(filenames)
        batch_size = batch_size_in
        index_list = []
        index_list = index_list_in
        cx = 1
        for index in index_list:
            tf.print("Epoch: {}, Batch: {}, Batches total: {}".format(epoch, cx, len(index_list_in)), summarize=32 *
            10 * 250, output_stream="file://logtest.txt")
            cx = cx + 1
            how_much_to_take = batch_size
            batch_of_records = raw_dataset.skip(index).take(how_much_to_take)
            max_number_of_tv_in_batch = 0
            batch = []

    [...]

            yield (tf.stack(input_batch), tf.stack(attention_mask_batch), tf.stack(padding_mask_batch)), tf.stack(output_batch)


So for each index in the index list it yields one batch, which means steps_per_epoch and validation_steps are equal to the number of batches generated by one loop through the respective index_list_in in the train generator and the val generator. So it should all work smoothly.
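To make the counting concrete, here is a minimal sketch in plain Python (no TensorFlow). batches_forever is a hypothetical, simplified stand-in for generate_dataset, and the index values are made up; only the count of 607 comes from the log further down.

```python
from itertools import islice


def batches_forever(index_list):
    """Simplified stand-in for generate_dataset: one item per index, repeated forever."""
    epoch = 0
    while True:
        epoch += 1
        for step, index in enumerate(index_list, start=1):
            yield epoch, step, index


# Hypothetical record offsets: 607 entries, so 607 batches per pass.
train_indices = list(range(0, 607 * 32, 32))
steps_per_epoch = len(train_indices)

# Taking exactly steps_per_epoch items covers one full pass and nothing more.
first_epoch = list(islice(batches_forever(train_indices), steps_per_epoch))
print(first_epoch[0][:2], first_epoch[-1][:2])  # (1, 1) (1, 607)
```

So as far as the generator itself is concerned, steps_per_epoch=len(train_indices) lines up exactly with one pass through the index list.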

As you can see, I print the progress to a log file using tf.print, and I noticed something I can't really explain.

Epoch: 1, Batch: 1, Batches total: 607
Epoch: 1, Batch: 2, Batches total: 607
Epoch: 1, Batch: 3, Batches total: 607
Epoch: 1, Batch: 4, Batches total: 607
Epoch: 1, Batch: 5, Batches total: 607
Epoch: 1, Batch: 6, Batches total: 607
Epoch: 1, Batch: 7, Batches total: 607 
[...]
Epoch: 1, Batch: 605, Batches total: 607
Epoch: 1, Batch: 606, Batches total: 607
Epoch: 1, Batch: 607, Batches total: 607
Epoch: 2, Batch: 1, Batches total: 607
Epoch: 1, Batch: 1, Batches total: 67
Epoch: 1, Batch: 2, Batches total: 67
Epoch: 1, Batch: 3, Batches total: 67
[...]
Epoch: 1, Batch: 66, Batches total: 67
Epoch: 1, Batch: 67, Batches total: 67
Epoch: 2, Batch: 1, Batches total: 67
Epoch: 2, Batch: 2, Batches total: 607
Epoch: 2, Batch: 3, Batches total: 607

So it basically loads the 607 batches as it should, but then it loads an additional batch, as you can tell from this output:

Epoch: 2, Batch: 1, Batches total: 607

That happens just before it goes into the generator for the val batches. For the val generator it is the same: it works perfectly fine and yields 67 batches as it should, but then it goes into the next loop and for some reason I see this:

Epoch: 2, Batch: 1, Batches total: 67

Then the second epoch in the train generator starts here:

Epoch: 2, Batch: 2, Batches total: 607

instead of:

Epoch: 2, Batch: 1, Batches total: 607

as it should.

And then the same thing happens in epoch 3:

Epoch: 3, Batch: 2, Batches total: 607

instead of:

Epoch: 3, Batch: 1, Batches total: 607

But why is it doing this? As I said, steps_per_epoch matches exactly the number of batches yielded by one loop through index_list_in, which is obvious since I use len(train_indices) and len(val_indices) to determine the number of batches. But it appears to need one more batch per epoch. Why?

What I want is for it to load exactly the 607 batches per epoch, and it should of course start with the same batch every epoch. What do I have to change? Is this a bug in Keras or am I doing something wrong?


Solution

  • Actually, I think I figured it out.

    I need to add the first batch twice at the beginning of the train generator in the first epoch. (All other epochs get only steps_per_epoch batches.)

    And I need to add the first batch twice at the beginning of the val generator for each epoch.

    Because for those batches no accuracy or loss is actually calculated (the model doesn't train on those batches). I think the model just needs those batches for context.

    It needs the context batch in the val generator every epoch because it expects one full dataset that repeats every epoch, while for the train dataset it goes through the data in steps_per_epoch steps, so it needs the context only once.

    I didn't find anything in the documentation about this, though, so it's either a bug or hard to find.
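An alternative reading of the log that may be worth checking (this is an assumption, not something confirmed in this thread): if fit keeps one batch buffered ahead of the training step, for example through a tf.data prefetch of size 1 around the Python generator, then the print inside the generator runs one batch ahead of what the model has actually consumed. The sketch below is plain Python with no TensorFlow in it; logging_generator and lookahead are hypothetical helpers that only imitate that kind of buffering.

```python
from collections import deque
from itertools import islice


def logging_generator(n_batches, log):
    """Loop forever over n_batches, recording (epoch, batch) when a batch is produced."""
    epoch = 0
    while True:
        epoch += 1
        for batch in range(1, n_batches + 1):
            log.append((epoch, batch))  # plays the role of the tf.print call
            yield (epoch, batch)


def lookahead(source, buffer_size=1):
    """Hypothetical stand-in for a prefetching consumer: stays buffer_size items ahead."""
    buffer = deque()
    for item in source:
        buffer.append(item)
        if len(buffer) > buffer_size:
            yield buffer.popleft()
    # Only reached for finite sources: hand out whatever is still buffered.
    while buffer:
        yield buffer.popleft()


log = []
stream = lookahead(logging_generator(n_batches=607, log=log))

# One "epoch" of steps_per_epoch=607 training steps.
consumed = list(islice(stream, 607))

print(consumed[-1])  # last batch actually handed to the consumer: (1, 607)
print(log[-1])       # last batch the generator logged: (2, 1)
print(len(log))      # 608
```

Under this assumption the 608th logged batch is neither dropped nor used as context. It sits in the buffer and becomes the first training step of epoch 2, which would also explain why the next training line in the log is Batch: 2. If that is what is happening, the batches the model sees are already the right ones, and only the log is shifted by one.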