
How can I apply the same augmentation to a batch of images?


I have a dataset of videos. Since the dataset is small, I am trying to augment the video data. I have not found any resources on augmenting videos, so here is what I think will work:

  1. Extract required frames from the video
  2. Apply data augmentation to the extracted frames

Now, let's say I have extracted 20 frames from a single video. For my data to make sense, I have to apply the same augmentation to all 20 of these frames. How can I achieve that? I am also open to other libraries if they make the work easier.
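The requirement can be made concrete with a library-agnostic sketch: draw every random parameter once per clip, then apply it to all frames. This assumes the frames are already stacked into a NumPy array of shape (frames, height, width, channels) with values in [0, 1]; `augment_clip` and the particular augmentations (horizontal flip, 90% random crop, brightness shift) are hypothetical, illustrative choices.

```python
import numpy as np


def augment_clip(frames, rng):
    """Apply one randomly drawn augmentation to every frame of a clip.

    frames: float array, shape (num_frames, height, width, channels), values in [0, 1].
    rng: a numpy.random.Generator.
    """
    _, height, width, _ = frames.shape
    # Draw every random parameter ONCE per clip, not once per frame.
    flip = rng.random() < 0.5
    crop_h, crop_w = int(height * 0.9), int(width * 0.9)
    top = int(rng.integers(0, height - crop_h + 1))
    left = int(rng.integers(0, width - crop_w + 1))
    brightness = rng.uniform(-0.2, 0.2)

    # Slicing and broadcasting apply the same parameters to all frames at once.
    out = frames[:, top:top + crop_h, left:left + crop_w, :]
    if flip:
        out = out[:, :, ::-1, :]
    return np.clip(out + brightness, 0.0, 1.0)


# Usage: 20 stand-in frames of size 64x64 with 3 channels.
clip = np.random.default_rng(1).random((20, 64, 64, 3))
augmented = augment_clip(clip, np.random.default_rng(0))
```

Because the parameters are drawn outside any per-frame loop, temporal consistency comes for free; calling `augment_clip` again on the next video draws a fresh set of parameters.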

I am guessing some changes to the ImageDataGenerator.flow_from_directory(...) arguments will do the trick. Here is its signature from the Keras documentation:

ImageDataGenerator.flow_from_directory(
    directory,
    target_size=(256, 256),
    color_mode="rgb",
    classes=None,
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=None,
    save_to_dir=None,
    save_prefix="",
    save_format="png",
    follow_links=False,
    subset=None,
    interpolation="nearest",
)
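`flow_from_directory` calls `get_random_transform` separately for every image it loads, so its arguments alone cannot tie the frames of one video together. A minimal sketch of a workaround, assuming a TensorFlow 2.x install where the deprecated `tf.keras.preprocessing.image.ImageDataGenerator` is still available (plus SciPy, which its affine transforms use): draw one transform with `get_random_transform` and reuse it with `apply_transform` for every frame. The helper `augment_clip_with_generator` and the augmentation settings are hypothetical.

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Illustrative augmentation settings; adjust to taste.
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
)


def augment_clip_with_generator(frames, generator, seed=None):
    """Apply one randomly drawn transform to every frame of a clip.

    frames: array of shape (num_frames, height, width, channels).
    """
    # Draw the transform parameters once, using a single frame's shape...
    params = generator.get_random_transform(frames.shape[1:], seed=seed)
    # ...then apply exactly the same parameters to each frame.
    return np.stack([generator.apply_transform(f, params) for f in frames])


clip = np.random.rand(20, 64, 64, 3).astype("float32")
augmented = augment_clip_with_generator(clip, datagen, seed=0)
```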

Thank you in advance!


Solution

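  • You can use a tf.data.Dataset and apply the transformations after the batching operation. Ops such as tf.image.random_hue and tf.image.random_brightness draw a single random value per call, so when they receive a whole batch, every image in it gets the same adjustment. This will require some work to make your own directory iterator (something like this), but here's the essence of it: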

    import tensorflow as tf
    import matplotlib.pyplot as plt
    from skimage import data
    
    # Stand-in for 24 extracted frames: the same sample image repeated.
    cats = tf.concat([data.chelsea()[None, ...] for _ in range(24)], axis=0)
    
    test = tf.data.Dataset.from_tensor_slices(cats)
    
    
    def augment(tensor):
        # Scale uint8 pixels to floats in [0, 1].
        tensor = tf.cast(x=tensor, dtype=tf.float32)
        tensor = tf.divide(x=tensor, y=tf.constant(255.))
        # One random delta per call, shared by every image in the batch.
        tensor = tf.image.random_hue(image=tensor, max_delta=5e-1)
        tensor = tf.image.random_brightness(image=tensor, max_delta=2e-1)
        return tensor
    
    
    # Batch first, then map, so `augment` receives 8 frames at a time.
    test = test.batch(8).map(augment)
    
    
    fig = plt.figure()
    plt.subplots_adjust(wspace=.1, hspace=.2)
    images = next(iter(test))  # one augmented batch of 8 frames
    for index, image in enumerate(images):
        ax = plt.subplot(4, 2, index + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        # The brightness shift can push values outside [0, 1]; clip for display.
        ax.imshow(tf.clip_by_value(image, clip_value_min=0, clip_value_max=1))
    plt.show()
    


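    Note that this doesn't work for tf.image.random_flip_left_right: given a 4-D batch, it draws a separate flip decision for each image, so frames in the same batch can end up flipped differently.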
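One way around the flip limitation, sketched under the assumption of TensorFlow 2.x: toss a single coin per batch with `tf.random.uniform([])` and branch with `tf.cond` to the deterministic `tf.image.flip_left_right`, which flips every image of a 4-D tensor. The sketch also assumes each batch holds exactly the frames of one video (a hypothetical 20 frames per video, kept in order, with no shuffling before batching), so that a per-batch augmentation is a per-video augmentation. `random_flip_whole_batch` and `FRAMES_PER_VIDEO` are hypothetical names.

```python
import tensorflow as tf

FRAMES_PER_VIDEO = 20  # hypothetical: number of frames extracted per video


def random_flip_whole_batch(batch):
    """Flip all images of a 4-D batch left-right with probability 0.5, or none."""
    # One scalar draw for the whole batch, unlike random_flip_left_right,
    # which draws one value per image for 4-D input.
    do_flip = tf.random.uniform([]) < 0.5
    return tf.cond(do_flip,
                   lambda: tf.image.flip_left_right(batch),
                   lambda: batch)


# Stand-in frames for two videos; in practice these come from your extractor.
frames = tf.random.uniform((2 * FRAMES_PER_VIDEO, 32, 32, 3))
dataset = (tf.data.Dataset.from_tensor_slices(frames)
           .batch(FRAMES_PER_VIDEO)  # one batch == one video
           .map(random_flip_whole_batch))
```

The same pattern (one scalar draw, then a deterministic op on the whole batch) can be reused for any other augmentation whose `random_*` variant decides per image.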