Tags: python, tensorflow, keras, keras-layer

Defining a Keras Custom Layer that adds a random value to a flatten layer output


How do I define a custom Keras layer that adds a random value to the output of a CNN's Flatten layer, whose output shape is (None, 100)?


Solution

  • TL;DR:

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    
    class Noise(keras.layers.Layer):
        def __init__(self, mean=0.0, stddev=1.0, **kwargs):
            super().__init__(**kwargs)
            self.mean = mean
            self.stddev = stddev
    
        def call(self, inputs,
                 training=False  # Only add noise in training!
                 ):
            if training:
                # tf.shape() returns the dynamic shape, so this also works
                # while the batch dimension is still None
                return inputs + tf.random.normal(
                    tf.shape(inputs),
                    mean=self.mean,
                    stddev=self.stddev,
                    dtype=inputs.dtype
                )  # Add random noise during training
            else:
                return inputs + self.mean  # Add mean of random noise during inference
    
    model = keras.Sequential([
        layers.Flatten(input_shape=(10, 10, 1)),  # output shape (None, 100)
        Noise(stddev=.1)
    ])
    
    input_batch = tf.random.uniform((32, 10, 10, 1))  # hypothetical example batch
    model(input_batch,
          training=True  # Defaults to False.
                         # Noise is added only in training mode.
    )
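
If the model will be saved and reloaded, a custom layer like Noise also needs a get_config method so Keras can re-create it with the same mean and stddev. The following is a minimal sketch assuming TensorFlow 2.x with tf.keras; the layer is redefined so the snippet runs on its own, and the mean=0.5 / stddev=0.1 values are only for illustration. It also checks the behaviour described above: exactly `mean` is added at inference, random noise during training.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class Noise(keras.layers.Layer):
    """Adds N(mean, stddev) noise in training and adds `mean` at inference."""

    def __init__(self, mean=0.0, stddev=1.0, **kwargs):
        super().__init__(**kwargs)
        self.mean = mean
        self.stddev = stddev

    def call(self, inputs, training=False):
        if training:
            # Dynamic shape, so an unknown batch size (None) is fine
            return inputs + tf.random.normal(
                tf.shape(inputs), mean=self.mean, stddev=self.stddev,
                dtype=inputs.dtype)
        return inputs + self.mean

    def get_config(self):
        # Store the constructor arguments so the layer can be re-created
        config = super().get_config()
        config.update({"mean": self.mean, "stddev": self.stddev})
        return config


layer = Noise(mean=0.5, stddev=0.1)
restored = Noise.from_config(layer.get_config())  # round-trip through the config

x = tf.zeros((4, 100))  # stands in for a Flatten output of shape (batch, 100)
inference_out = restored(x, training=False).numpy()  # every value is 0.5
training_out = restored(x, training=True).numpy()    # values scatter around 0.5
```

Without get_config, saving usually still works, but reloading the model cannot rebuild Noise with its original arguments.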
    


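    There is also a built-in keras.layers.GaussianNoise layer that does the same thing as my Noise above when mean=0: it adds zero-mean Gaussian noise during training and passes its inputs through unchanged at inference.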
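
For the zero-mean case, the built-in layer can replace the custom one. A sketch assuming TensorFlow 2.x tf.keras, where GaussianNoise takes the standard deviation as its `stddev` argument; the 10x10x1 input mirrors the question's (None, 100) Flatten output, and the all-ones batch is only for demonstration.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    keras.Input(shape=(10, 10, 1)),
    layers.Flatten(),                  # output shape (None, 100)
    layers.GaussianNoise(stddev=0.1),  # zero-mean noise, active only in training
])

batch = np.ones((8, 10, 10, 1), dtype="float32")

inference_out = model(batch, training=False).numpy()  # identical to the flattened input
training_out = model(batch, training=True).numpy()    # flattened input + N(0, 0.1) noise
```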

    Several notes to bear in mind when using the code above:

    • If you intend to use random noise as a regularizer to fight overfitting, it is much, much better to use Keras' built-in image augmentation layers (such as keras.layers.RandomFlip and keras.layers.RandomRotation) on the input images.
    • Avoid using non-normal distributions when dealing with CNNs. Using a uniform distribution on [0, 1) (the default range of tf.random.uniform), for example, will shift the mean value of a batch, negating the image normalization that CNNs so desperately need.
    • Consider using dropout (keras.layers.Dropout) if the output of Flatten is fed into a dense classifier on top. Dropout is usually much more effective for what you are probably trying to do.
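
The alternatives in these notes can be sketched together, assuming TensorFlow 2.6 or later, where RandomFlip and RandomRotation are available directly under keras.layers. The small CNN (one Conv2D with 4 filters, 10 output classes, dropout rate 0.5) is hypothetical and only meant to show where each layer goes. The last lines are a quick numeric check of the mean-shift point: uniform noise on [0, 1) averages about 0.5, while standard normal noise averages about 0.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical small CNN: 10x10 grayscale input, 10 output classes
model = keras.Sequential([
    keras.Input(shape=(10, 10, 1)),
    # Augmentation layers: they only transform images when training=True
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.Conv2D(4, 3, activation="relu"),  # -> (None, 8, 8, 4)
    layers.Flatten(),                        # -> (None, 256)
    layers.Dropout(0.5),                     # regularizes the dense head, training only
    layers.Dense(10, activation="softmax"),
])

preds = model(np.zeros((2, 10, 10, 1), dtype="float32"), training=False).numpy()

# Mean-shift check: adding this noise moves the batch mean by roughly these amounts
uniform_mean = float(tf.reduce_mean(tf.random.uniform((100_000,))))
normal_mean = float(tf.reduce_mean(tf.random.normal((100_000,))))
```

Augmentation acts on the raw images, so the network sees plausible variations of real inputs, whereas dropout acts on the Flatten output, which is where the question wanted to inject randomness.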

    For any clarification, please don't hesitate to comment! Cheers.