Tags: tensorflow, data-augmentation, data-preprocessing, image-masking

Preprocessing layers with seed not producing the same data augmentation for images and masks


I'm trying to create a simple preprocessing augmentation layer, following this TensorFlow tutorial. I wrote the minimal example below to show the problem I'm having.

Even though I initialize the augmentation class with a seed, the operations applied to the images and to the corresponding masks are not always the same.

What am I doing wrong?

Note: tf v2.10.0

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import skimage.color
import skimage.filters
import rasterio as rio

def normalize(array: np.ndarray):
    """Rescale the array linearly to the [0, 1] range."""
    array_min, array_max = array.min(), array.max()
    return (array - array_min) / (array_max - array_min)

# field: first three raster bands stacked into a (1, height, width, 3) array
im = rio.open('penguins.tif')
fields = np.zeros((1,im.shape[0],im.shape[1],3))
fields[0,:,:,0] = normalize(im.read(1))
fields[0,:,:,1] = normalize(im.read(2))
fields[0,:,:,2] = normalize(im.read(3))

# mask: a simple contour (grayscale Sobel edge map), shape (1, height, width, 1)
masks = skimage.color.rgb2gray(skimage.filters.sobel(fields[0]))
masks = np.expand_dims(masks, [0,3])

In this case the dataset is a single image. This function displays the field and its mask side by side:

def show(field: np.ndarray, mask: np.ndarray):
    """Show the field and corresponding mask."""
    fig = plt.figure(figsize=(8, 6))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.imshow(field[:, :, :3])
    ax2.imshow(mask, cmap='binary')
    plt.tight_layout()
    plt.show()

show(fields[0], masks[0])

[Image: the field and its mask]

Next, I used the example from the tutorial, which randomly flips the image and the mask horizontally.

class Augment(tf.keras.layers.Layer):
    def __init__(self, seed=42):
        super().__init__()
        # both use the same seed, so they'll make the same random changes.
        self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
        self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

    def call(self, inputs, labels):
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels

If I run the following several times, I eventually get the field and the mask flipped differently (one flipped, the other not).

# Create a tf.data.Dataset
ds = tf.data.Dataset.from_tensor_slices((fields, masks))

ds2 = ds.map(Augment())

for f,m in ds2.take(1):
    show(f, m)

[Image: the field and the mask flipped differently]

I would expect the image and its mask to be flipped the same way, since I set the seed in the Augment class as suggested in the TensorFlow tutorial.
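To tell whether a given run kept the pair aligned, one option is to compare each augmented array with the unflipped and the horizontally flipped version of its original. The sketch below uses hypothetical helpers (`applied_flip` and `flips_match` are names made up for illustration) and assumes unbatched `(height, width, channels)` arrays such as `fields[0]`, `masks[0]`, `f` and `m` above.

```python
import numpy as np


def applied_flip(original: np.ndarray, augmented) -> str:
    """Return which horizontal flip turned `original` into `augmented`.

    Hypothetical helper: assumes (height, width, channels) arrays and only
    distinguishes 'none' from 'horizontal' (the RandomFlip mode used above).
    A perfectly left-right symmetric array is reported as 'none'.
    """
    # np.asarray also accepts eager TF tensors; cast to compare like with like.
    augmented = np.asarray(augmented, dtype=original.dtype)
    if np.allclose(augmented, original):
        return "none"
    if np.allclose(augmented, original[:, ::-1, :]):
        return "horizontal"
    return "unknown"


def flips_match(field, mask, aug_field, aug_mask) -> bool:
    """True when the field and its mask received the same flip."""
    return applied_flip(field, aug_field) == applied_flip(mask, aug_mask)
```

For example, `for f, m in ds2.take(1): print(flips_match(fields[0], masks[0], f, m))` should print `False` on the runs where the two flips disagree.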


Solution

  • One workaround is to concatenate the image and the mask along the channel axis, augment that single array, and then split it back into the image and the label:

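    class Augment(tf.keras.layers.Layer):
        def __init__(self):
            super().__init__()
            # A single layer transforms image and mask together,
            # so both always receive the same random change.
            self.augment_inputs = tf.keras.layers.RandomRotation(0.3)

        def call(self, inputs, labels):
            # Stack along the channel axis (image and mask must share a dtype).
            output = self.augment_inputs(tf.concat([inputs, labels], -1))

            # Split at the image's channel count (3 for `fields`, so the mask
            # starts at index 3); assumes the channel dimension is statically known.
            num_channels = inputs.shape[-1]
            inputs = output[..., :num_channels]
            labels = output[..., num_channels:]
            return inputs, labels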
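Another option, not part of the original answer, is to skip the Keras layers and use TensorFlow's stateless random image ops, which are deterministic for a given `seed`: passing the same seed to the image call and to the mask call yields the same flip decision. A minimal sketch, assuming TF 2.x where `tf.image.stateless_random_flip_left_right` and `tf.random.Generator` are available (both are in 2.10); `augment_pair` and `augment_with_fresh_seed` are hypothetical names.

```python
import tensorflow as tf


def augment_pair(image, mask, seed):
    """Flip image and mask together; `seed` is a shape-[2] int32/int64 tensor."""
    # The same seed gives the same random decision for both tensors,
    # regardless of how many channels each one has.
    image = tf.image.stateless_random_flip_left_right(image, seed=seed)
    mask = tf.image.stateless_random_flip_left_right(mask, seed=seed)
    return image, mask


# One generator shared by the pipeline, so every element gets a new seed.
rng = tf.random.Generator.from_seed(42)


def augment_with_fresh_seed(image, mask):
    seed = rng.make_seeds(1)[:, 0]  # shape-[2] int64 seed
    return augment_pair(image, mask, seed)


# Usage with the dataset from the question:
# ds2 = ds.map(augment_with_fresh_seed)
```

The seed lives outside the transform here, so the pairing no longer depends on two separate layers staying in step.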