python tensorflow tensorflow-datasets data-augmentation

How do I correctly apply data augmentation to a TFRecord Dataset?


I am attempting to apply data augmentation to a TFRecord dataset after it has been parsed. However, when I count the examples before and after mapping the augmentation function, the counts are the same. I know the parse function works and the datasets are correct, as I have already used them to train a model, so I have only included the code that maps the function and counts the examples.

Here is the code I am using:

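import tensorflow as tf

# parse_function and flip are my own helpers (not shown); both have been tested separately.

def flip_example(image, label):
    flipped_image = flip(image)
    return flipped_image, label


dataset = tf.data.TFRecordDataset('train.TFRecord').map(parse_function)

# Size of the parsed dataset
num_ex = 0
for x in dataset:
    num_ex += 1
print(num_ex)

# Size of the dataset after mapping the augmentation function
num_ex = 0
dataset = dataset.map(flip_example)
for x in dataset:
    num_ex += 1
print(num_ex)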

In both cases num_ex = 324, whereas I expected 324 for the non-augmented dataset and 648 for the augmented one. I have also tested the flip function successfully, so the issue seems to be in how the function interacts with the dataset. How do I implement this augmentation correctly?


Solution

  • When you apply data augmentation with the tf.data API, it happens on the fly: Dataset.map transforms each example as implemented in your function and produces exactly one output element per input element. The flipped image therefore replaces the original rather than being added alongside it, so the number of examples in the pipeline does not change.

    If you want to use every example n times, add dataset = dataset.repeat(count=n). You may also want to switch to tf.image.random_flip_left_right; otherwise every repetition applies the same flip and the repeated copies are identical.
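A minimal sketch of the behaviour described above, assuming TensorFlow 2.x with eager execution. The 324 random 8x8 RGB images are a hypothetical stand-in for the parsed TFRecord data, and random_flip_example and count_examples are illustrative helper names. It shows that map alone keeps the count at 324, while repeat(count=2) followed by a random flip yields 648 elements.

```python
import tensorflow as tf

# Hypothetical stand-in for the parsed TFRecord data: 324 random 8x8 RGB
# images with integer labels (the real parse_function is not shown).
images = tf.random.uniform((324, 8, 8, 3))
labels = tf.range(324)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))


def random_flip_example(image, label):
    # Flips the image horizontally with probability 0.5; the coin is
    # re-tossed every time the element is produced.
    return tf.image.random_flip_left_right(image), label


def count_examples(ds):
    # Counting by iteration, as in the question; this also works when
    # tf.data cannot infer the cardinality up front (as for a TFRecordDataset).
    return sum(1 for _ in ds)


mapped = dataset.map(random_flip_example)                    # one output per input
repeated = dataset.repeat(count=2).map(random_flip_example)  # every example twice

n_original = count_examples(dataset)
n_mapped = count_examples(mapped)
n_repeated = count_examples(repeated)
print(n_original, n_mapped, n_repeated)  # 324 324 648
```

Because each pass flips independently with probability 0.5, on average half of the repeated pairs still come out identical.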
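An alternative that yields exactly the 648 examples the question expects, with every original appearing once un-flipped and once flipped, is to concatenate the dataset with a flipped copy of itself using Dataset.concatenate. The sketch below assumes TensorFlow 2.x; the random images are hypothetical stand-in data, and tf.image.flip_left_right stands in for the question's own flip function.

```python
import tensorflow as tf

# Hypothetical stand-in for the parsed dataset: 324 random 8x8 RGB images.
images = tf.random.uniform((324, 8, 8, 3))
labels = tf.range(324)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))


def flip_example(image, label):
    # tf.image.flip_left_right stands in for the question's flip().
    return tf.image.flip_left_right(image), label


# Keep every original example and append a deterministically flipped copy.
augmented = dataset.concatenate(dataset.map(flip_example))
# Shuffle so originals and flipped copies are interleaved during training.
augmented = augmented.shuffle(buffer_size=648)

num_ex = sum(1 for _ in augmented)
print(num_ex)  # 648
```

Unlike repeat plus a random flip, this guarantees that both orientations of every image are present in each epoch.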