Tags: tensorflow, machine-learning, keras, deep-learning, conv-neural-network

How to perform same data augmentation on 2 different image inputs?


I want to construct a multi-path CNN with 2 image inputs, where each input passes through its own CNN and the resulting features are concatenated at the end. Right now I have something like this:

Features & labels for the 1st set of images: X (features), y (labels)

Features & labels for the 2nd set of images (segmented): X_2 (features), y_2 (labels)

My code for the augmentation is this:

data_augmentation = tf.keras.Sequential()
data_augmentation.add(tf.keras.layers.RandomFlip("horizontal_and_vertical"))
data_augmentation.add(tf.keras.layers.RandomRotation(0.2))
data_augmentation.add(tf.keras.layers.RandomZoom(height_factor=(.05),
                                                 width_factor=(.05)))

and I want the CNN to look something like this: [CNN model diagram]

Both the 1st and 2nd sets of images are the same size and in the same order; the only difference is that the 2nd set is segmented. I want to apply the same data augmentation to both of them. Does anyone know how to do this?

I tried looking for answers online, but most of the websites I found that talk about multi-input CNNs didn't use an augmentation layer, so I'm stuck.


Solution

  • (See below for the original answer.)

    I guess this Keras documentation is relevant to data augmentation in Keras. In particular, it mentions two options for using data augmentation layers: making the layers part of the Model, and applying them in the tf.data pipeline.
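
    For reference, a minimal sketch of the first option (augmentation as part of the Model, here via the functional API) might look like the following; the input shape (128, 128, 3) is an assumed placeholder. The key idea, reused in the pipeline version below, is to concatenate the two inputs along the channel axis so a single random transform is applied to both, then split the result back apart.

    import tensorflow as tf
    
    # Assumed placeholder shapes; adjust to your images.
    color_in = tf.keras.Input(shape=(128, 128, 3), name="color")
    segment_in = tf.keras.Input(shape=(128, 128, 3), name="segment")
    
    augment = tf.keras.Sequential([
        tf.keras.layers.RandomFlip("horizontal_and_vertical"),
        tf.keras.layers.RandomRotation(0.2),
    ])
    
    # One tensor in, one random transform out: both halves are augmented identically.
    merged = tf.keras.layers.Concatenate(axis=-1)([color_in, segment_in])
    augmented = augment(merged)  # these layers are only active in training mode
    color_aug = augmented[..., :3]    # first 3 channels: color image
    segment_aug = augmented[..., 3:]  # remaining channels: segmented image
    # ...feed color_aug and segment_aug into the two CNN branches...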

    Here is a rough (think of it as pseudo-code) example implementation using the second option.

    Additionally, my personal preference would be to implement the Model described in the question as a custom Model for flexibility, but perhaps there are ways to do the same with the higher-level Keras API.

    In the following snippet, a custom Model is set up, followed by the training dataset pipeline. It lacks the exact loading steps but should give a rough idea of how to connect the dataset to a custom Model.

    import tensorflow as tf
    
    
    class TwinCNN(tf.keras.Model):
        def __init__(self, name="TwinCNN"):
            super().__init__(name=name)
            
            self.cnn_for_color = [
                tf.keras.layers.Conv2D(128, 3, padding="same", activation="relu"), 
                tf.keras.layers.Conv2D(128, 3, padding="same", activation="relu"), 
                tf.keras.layers.Conv2D(128, 3, padding="same", activation="relu"), 
            ]  # Some layers
            
            self.cnn_for_segment = [
                tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu"), 
                tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu"), 
                tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu"), 
            ]
            
            self.classifier = [
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(
                    units=4,  # seems there are 4 classes
                    activation=tf.nn.softmax)
            ]
        
        def call(self, inputs, training=None):
            color_image = inputs['color']
            segment_image = inputs['segment']
            
            color_fts = color_image
            for ly in self.cnn_for_color:
                color_fts = ly(color_fts, training=training)
            
            segment_fts = segment_image 
            for ly in self.cnn_for_segment:
                segment_fts = ly(segment_fts, training=training)
            
            concat_fts = tf.concat([color_fts, segment_fts], axis=-1, name='concat_fts')
            
            net = concat_fts
            for ly in self.classifier:
                net = ly(net, training=training)
            return net  # shape Nx4
    
    
    # Your data augmentation layers
    data_augmentation = tf.keras.Sequential()
    data_augmentation.add(tf.keras.layers.RandomFlip("horizontal_and_vertical"))
    data_augmentation.add(tf.keras.layers.RandomRotation(0.2))
    data_augmentation.add(tf.keras.layers.RandomZoom(height_factor=(.05),
                                                     width_factor=(.05)))
    
    def apply_augment_to_zipped(color_img, segment_img):
        c_color = color_img.shape[-1]  # Number of channels in the color image
        # Concatenate along the channel axis so a single random transform
        # is applied to both images at once.
        concat_img = tf.concat([color_img, segment_img], axis=-1)
        # **For your test/validation dataset, don't use augmentation!**
        img_aug = data_augmentation(concat_img, training=True)
        # Split the augmented tensor back into the two images.
        color_aug = img_aug[..., :c_color]
        segment_aug = img_aug[..., c_color:]
        return color_aug, segment_aug
    
    
    def turn_image_tuple_to_dict(color_img, segment_img):
        # Put the two images into a dict
        return dict(color=color_img, segment=segment_img)
    
    
    # Assuming ds_color and ds_segment have the *same* labels, so we can take either one
    ds_color_images = tf.data.Dataset.from_tensor_slices(...)  # Set up a dataset for color images
    ds_labels = tf.data.Dataset.from_tensor_slices(...)  # Set up a dataset for labels
    ds_segment_images = tf.data.Dataset.from_tensor_slices(...)  # Set up a dataset for segment images
    
    ds_images = tf.data.Dataset.zip((ds_color_images, ds_segment_images))  # each entry will be a tuple of 2 images
    ds_images_augmented = ds_images.map(apply_augment_to_zipped)
    ds_images_augmented = ds_images_augmented.map(turn_image_tuple_to_dict)
    ds_train = tf.data.Dataset.zip((ds_images_augmented, ds_labels))  # The training set 
    
    # NOTE: for a validation dataset, we should skip the augmentation (i.e. don't call apply_augment_to_zipped)
    
    model = TwinCNN()
    # TODO: select an optimizer, 
    # TODO: model.compile(optimizer)
    model.fit(
        ds_train, # TODO: other training settings
    )
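
    As mentioned in the NOTE, a validation pipeline should skip the augmentation map. A rough sketch, where ds_color_val, ds_segment_val and ds_labels_val are hypothetical placeholder datasets built like the training ones above:

    ds_val_images = tf.data.Dataset.zip((ds_color_val, ds_segment_val))
    ds_val_images = ds_val_images.map(turn_image_tuple_to_dict)  # no apply_augment_to_zipped here
    ds_val = tf.data.Dataset.zip((ds_val_images, ds_labels_val))
    
    model.fit(ds_train, validation_data=ds_val)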
    
    
    

    Original answer

    Since the two datasets are of the same size, one can use

    ds = tf.data.Dataset.zip((ds1, ds2))
    

    to create a dataset where each entry is a pair of images, say (x, x2). Then one could do something like

    ds = ds.map(lambda x, x2: tf.concat([x, x2], axis=-1))
    

    to merge x and x2 into a single tensor, which suits the augmentation layers better.

    Assuming the number of channels is constant for each x (say 3 channels) and each x2 (say 1 channel), the two images can be sliced back out of the augmented tensor afterwards:

    x_aug = augmented_tensor[..., :3]    # the 3 color channels
    x2_aug = augmented_tensor[..., 3:4]  # the 1 segmentation channel
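
    Putting those pieces together, the whole round trip (merge, augment, split) can live in a single map call. A rough sketch, again assuming a 3-channel x and a 1-channel x2:

    def augment_pair(x, x2):
        merged = tf.concat([x, x2], axis=-1)  # shape (H, W, 4)
        augmented = data_augmentation(merged, training=True)
        return augmented[..., :3], augmented[..., 3:4]
    
    ds = tf.data.Dataset.zip((ds1, ds2)).map(augment_pair)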
    

    As a side note, the question would benefit from more detail. For example, it would be helpful to provide some code or pseudo-code explaining how exactly the data should flow through the "multi-path CNN" and what sort of data augmentation is being attempted.