Generating Crops On tf.data.Dataset


I want to do data augmentation on a tf.data.Dataset: generate 10-20 crops of each image and put them back into the tf.data.Dataset.

The transformation of the dataset should look like this:

 [(image1; label), (image2; label)] -> [(crop1image1; label), (crop2image1; label), ... (crop9image2; label), (crop10image2; label)]

How can this be done?


Solution

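  • The flat_map method of tf.data.Dataset can be used for this. The function passed to flat_map must return a tf.data.Dataset for each input element, and flat_map concatenates those datasets into one: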

    import tensorflow as tf
    
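    # Function to generate multiple crops for an image.
    # flat_map requires this function to return a tf.data.Dataset, not a Python list.
    def generate_crops(image, label):
        # Number of crops for this image: 10 to 20 (maxval is exclusive).
        # int64 because Dataset.repeat expects an int64 count.
        num_crops = tf.random.uniform([], minval=10, maxval=21, dtype=tf.int64)
    
        def random_crop(img, lbl):
            # Randomly crop the image
            crop = tf.image.random_crop(img, size=[crop_height, crop_width, num_channels])
            return crop, lbl
    
        # Repeat the (image, label) pair num_crops times and crop each copy independently
        return tf.data.Dataset.from_tensors((image, label)).repeat(num_crops).map(random_crop)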
    
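    # Define the size of the crops, number of channels, and batch size
    crop_height = 32
    crop_width = 32
    num_channels = 3
    batch_size = 32  # example value; pick what fits your training setup
    
    # Load your dataset here (e.g., using tf.data.Dataset.from_tensor_slices()).
    # Placeholder data for illustration: 4 random 64x64 RGB images with labels.
    images = tf.random.uniform([4, 64, 64, num_channels])
    labels = tf.constant([0, 1, 0, 1])
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))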
    
    # Apply data augmentation and generate multiple crops for each image
    dataset = dataset.flat_map(generate_crops)
    
    # Shuffle and batch the dataset as needed
    dataset = dataset.shuffle(buffer_size=1000).batch(batch_size)
    
    for batch in dataset:
        # Perform training steps here
        pass
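
If the number of crops per image can be fixed rather than random, a possible alternative is to build all crops for one image inside map and then split them apart with unbatch. The sketch below is only an illustration under assumptions: NUM_CROPS, CROP_SIZE, the fixed_crops helper, and the random placeholder images are not from the question, and the input images are assumed to be at least as large as the crop size.

```python
import tensorflow as tf

NUM_CROPS = 10             # assumed fixed number of crops per image
CROP_SIZE = (32, 32, 3)    # assumed crop height, width, channels

def fixed_crops(image, label):
    # Stack NUM_CROPS random crops of the same image into one tensor.
    crops = tf.stack(
        [tf.image.random_crop(image, CROP_SIZE) for _ in range(NUM_CROPS)]
    )
    # Give every crop the label of its source image.
    labels = tf.fill([NUM_CROPS], label)
    return crops, labels

# Placeholder data: 3 random 64x64 RGB images with integer labels.
images = tf.random.uniform([3, 64, 64, 3])
labels = tf.constant([0, 1, 2])

dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .map(fixed_crops)   # each element: (NUM_CROPS crops, NUM_CROPS labels)
    .unbatch()          # split into individual (crop, label) pairs
)

results = list(dataset.as_numpy_iterator())
```

Each element of results is one (crop, label) pair, with the crops of an image kept together in order until the dataset is shuffled. The flat_map approach is still the more flexible one when the crop count has to vary per image.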