Search code examples
Tags: python, tensorflow, keras, data-augmentation

How to set a flip probability for tf.keras.layers.RandomFlip?


Is it possible to set a probability for the random flip performed by tf.keras.layers.RandomFlip?

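For example, something like this (`p` is not a real argument of these layers; it only illustrates the desired API):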

def augmentation():
    """Sketch of the *desired* API: flip and rotate, each with probability 0.5.

    NOTE: ``keras.layers.RandomFlip`` and ``keras.layers.RandomRotation`` do
    not accept a ``p`` argument. This function only illustrates what the
    question is asking for, and constructing these layers as written raises
    ``TypeError``.

    Returns:
        A ``keras.Sequential`` augmentation pipeline (hypothetically).
    """
    data_augmentation = keras.Sequential([
        keras.layers.RandomFlip("horizontal", p=0.5),  # hypothetical `p`
        keras.layers.RandomRotation(0.2, p=0.5),  # hypothetical `p`
    ])
    return data_augmentation

Solution

  • Try creating a simple Lambda layer and defining your probability in a separate function:

    import random
    
    def random_flip_on_probability(image, probability=0.5):
        """Flip ``image`` left-right with the given ``probability``.

        The coin toss uses ``tf.random.uniform`` rather than Python's
        ``random.random()``. A Python-level draw runs only once when Keras or
        ``tf.function`` traces this function into a graph, which would freeze
        the decision (always flip, or never flip). The flip itself is the
        deterministic ``tf.image.flip_left_right``.
        ``tf.image.random_flip_left_right`` would add its own 50% coin on top,
        halving the effective probability.

        Args:
            image: A single image (rank 3) or a batch of images (rank 4).
            probability: Chance in [0, 1] that an image is flipped. For a
                batch, each image gets its own independent draw.

        Returns:
            A tensor with the same shape as ``image``.
        """
        image = tf.convert_to_tensor(image)
        flipped = tf.image.flip_left_right(image)
        if image.shape.rank == 4:
            # Independent coin per image; broadcast the (batch,) mask over H, W, C.
            do_flip = tf.random.uniform(tf.shape(image)[:1]) < probability
            return tf.where(
                do_flip[:, tf.newaxis, tf.newaxis, tf.newaxis], flipped, image
            )
        return tf.cond(
            tf.random.uniform(()) < probability,
            lambda: flipped,
            lambda: image,
        )
    
    def augmentation():
        """Build the augmentation pipeline: probabilistic flip, then random rotation.

        ``keras.layers.RandomRotation`` has no ``p`` (probability) argument, and
        passing one raises ``TypeError``. The rotation is therefore not gated by
        a probability here; its angle is random, with the range set by
        ``factor=0.2``.

        Returns:
            A ``keras.Sequential`` model applying both augmentations in order.
        """
        data_augmentation = keras.Sequential([
            # Flips with the default ``probability`` of random_flip_on_probability.
            keras.layers.Lambda(random_flip_on_probability),
            keras.layers.RandomRotation(0.2),
        ])
        return data_augmentation
    

    If you need to use data augmentation during training or inference, you will have to define your own custom layer. Try something like this:

    import tensorflow as tf
    import pathlib
    
    class RandomFlipOnProbability(tf.keras.layers.Layer):
        """Keras layer that flips images left-right, each with probability ``probability``.

        Every image in a batch gets its own independent coin toss. The layer is
        a no-op when ``training`` is falsy, matching the built-in
        ``tf.keras.layers.RandomFlip``, which leaves inputs unchanged at
        inference.

        Args:
            probability: Chance in [0, 1] that a given image is flipped.
            **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
                (e.g. ``name``), forwarded to the base class.
        """

        def __init__(self, probability, **kwargs):
            super().__init__(**kwargs)
            self.probability = probability

        def call(self, images, training=True):
            """Randomly flip ``images`` left-right.

            Args:
                images: A single image (rank 3) or a batch of images (rank 4).
                training: When falsy (e.g. inference via ``model.predict``),
                    ``images`` are returned unchanged. Defaults to ``True`` so
                    that calling the layer directly still augments.

            Returns:
                A tensor with the same shape as ``images``.
            """
            if not training:
                return images
            images = tf.convert_to_tensor(images)
            flipped = tf.image.flip_left_right(images)
            if images.shape.rank == 4:
                # One draw per image; broadcast the (batch,) mask over H, W, C.
                do_flip = tf.random.uniform(tf.shape(images)[:1]) < self.probability
                return tf.where(
                    do_flip[:, tf.newaxis, tf.newaxis, tf.newaxis], flipped, images
                )
            return tf.cond(
                tf.random.uniform(()) < self.probability,
                lambda: flipped,
                lambda: images,
            )

        def get_config(self):
            """Return the layer config, including ``probability``, so the layer can be re-created."""
            config = super().get_config()
            config.update({"probability": self.probability})
            return config
    
    import matplotlib.pyplot as plt  # needed for plt.imshow / plt.show below

    # Download (network access required) and extract the flower photos archive.
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)

    batch_size = 32

    # Training split (80%) of the images, resized to 180x180 and batched.
    train_ds = tf.keras.utils.image_dataset_from_directory(
      data_dir,
      validation_split=0.2,
      subset="training",
      seed=123,
      image_size=(180, 180),
      batch_size=batch_size)

    random_layer = RandomFlipOnProbability(probability=0.9)
    # Presumably pixel values arrive in [0, 255]; rescale to [0, 1] for imshow.
    normalization_layer = tf.keras.layers.Rescaling(1./255)

    # Take one batch, flip it randomly, rescale, and show the first image.
    images, _ = next(iter(train_ds.take(1)))
    images = normalization_layer(random_layer(images))
    image = images[0]

    plt.imshow(image.numpy())
    plt.show()  # required to display the figure when run as a plain script
    

    (Output image omitted: the first image of the batch after the random flip and rescaling.)