
How to make the TensorFlow map() function return multiple values?


I am trying to write a function that augments images from a dataset. I can already augment an image and return it, but I want to apply several augmentations to a single image, return each augmented image individually, and then add them all to the original dataset.

Augmentation function:

import tensorflow as tf

def augment_data(image, label):
    h_flipped_image = tf.image.flip_left_right(image)
    v_flipped_image = tf.image.flip_up_down(image)  # computed, but not returned yet

    return h_flipped_image, label

Map function:

train_ds = train_ds.map(augment_data)

train_ds is a batched tf.data Dataset with the following shapes and types:

<PrefetchDataset shapes: ((None, 224, 224, 3), (None, 238)), types: (tf.float32, tf.bool)>

How can I make the map function return multiple values so that I could, for example, return both h_flipped_image and v_flipped_image and add them to the train_ds dataset?
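The solution below ends up sidestepping the question, but for reference, here is a minimal sketch of one way to get several outputs per input element, assuming TensorFlow 2.x: have the function return a small Dataset and use Dataset.flat_map instead of map, which splices every returned element into the output stream. The toy data (toy_images, toy_labels, toy_ds) is a hypothetical stand-in. The sketch works on unbatched elements; since the train_ds shown above is already batched, it would need train_ds.unbatch() first and .batch(...) again afterwards.

```python
import tensorflow as tf

def augment_data(image, label):
    # Build the original image plus two flipped copies (expects one 3-D image).
    images = tf.stack([
        image,
        tf.image.flip_left_right(image),
        tf.image.flip_up_down(image),
    ])
    # Repeat the label once per image so every (image, label) pair stays aligned.
    labels = tf.stack([label, label, label])
    # Return a tiny Dataset; flat_map splices its elements into the output stream.
    return tf.data.Dataset.from_tensor_slices((images, labels))

# Hypothetical stand-in data: 4 unbatched 8x8 RGB images with 5 boolean labels
# each, mirroring the (float image, bool label) structure of the real dataset.
toy_images = tf.random.uniform([4, 8, 8, 3])
toy_labels = tf.zeros([4, 5], dtype=tf.bool)
toy_ds = tf.data.Dataset.from_tensor_slices((toy_images, toy_labels))

# Each input element becomes three output elements: original, h-flip, v-flip.
augmented_ds = toy_ds.flat_map(augment_data)

# Shuffle so the three versions of an image are not always adjacent, then batch.
train_ready = augmented_ds.shuffle(buffer_size=12).batch(4)
```

An alternative is train_ds.concatenate(train_ds.map(...)), which appends the flipped copies after the originals instead of interleaving them, so shuffling afterwards matters even more.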


Solution

  • It turns out I was not looking at the problem from the right direction. I realized that I do not need the augmented samples to be included in the dataset after all. Instead, I chose to augment the data on the fly during training.

    Since training runs for multiple epochs, I can augment each image right before the network needs it. I did this by modifying my augment_data function so that it applies each augmentation only with some random probability. Every epoch, a different random combination of augmentations is applied, so the network sees a different version of each image each time.

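    import tensorflow as tf

    def augment_data(image, label):
        # Draw a scalar in [0, 1); a new value is drawn for every element processed
        rand = tf.random.uniform([])
        if rand > 0.5:  # AutoGraph turns this into a tf.cond inside Dataset.map
            image = tf.image.flip_up_down(image)

        # Additional augmentation techniques should be defined here

        return image, label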
    

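    Make sure you use a TensorFlow op such as tf.random.uniform to generate the random number. A plain random.random() will not work: Dataset.map traces augment_data into a graph once, so random.random() would be evaluated only during tracing and the same value would be reused for every image.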
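To see why the TensorFlow op matters, the following sketch (assuming TensorFlow 2.x; the function names are hypothetical) maps two functions over a five-element dataset: one calls Python's random.random(), the other tf.random.uniform.

```python
import random
import tensorflow as tf

def python_random_value(_):
    # random.random() runs only while map() traces this function into a graph,
    # so the drawn number is frozen into the graph as a constant.
    return tf.constant(random.random())

def tf_random_value(_):
    # tf.random.uniform is a graph op, so it draws a fresh value for every element.
    return tf.random.uniform([])

ds = tf.data.Dataset.range(5)
python_values = [float(v) for v in ds.map(python_random_value)]
tf_values = [float(v) for v in ds.map(tf_random_value)]
print(python_values)  # five copies of the same number
print(tf_values)      # five (almost surely) different numbers
```

Two related caveats. tf.data re-runs the mapped function every time the dataset is iterated, so each epoch gets fresh random draws, unless a .cache() is placed after the augmentation map, in which case the first epoch's results are replayed. Also, because train_ds above is already batched, tf.random.uniform([]) in augment_data is drawn once per batch, so a whole batch is flipped or left alone together; tf.image.random_flip_up_down, given a 4-D batch, should instead make an independent flip decision for each image.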