Tags: tensorflow, keras, tensorflow2.0, tensor, batch-normalization

How to specify linear transformation in tensorflow?


I want to perform a simple linear transformation on a layer x, so the output of the transformation is y = a*x + b. I am working with images, so x is 3-dimensional (height * width * channels). Then a is a scale vector of size c, where c is the number of channels, holding a single scale parameter for each channel dimension of x. Similarly, b is a shift vector of size c, holding a single shift parameter for each channel dimension of x. This is a simple variation of batch normalization that skips normalizing with the batch statistics.
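
For concreteness, here is the intended broadcasting in plain NumPy (the shapes and values are purely illustrative):

import numpy as np

h, w, c = 24, 24, 4            # height, width, channels
x = np.random.randn(h, w, c)   # one feature map
a = np.full(c, 2.0)            # one scale parameter per channel
b = np.full(c, 0.5)            # one shift parameter per channel

# a and b broadcast along the last (channel) axis
y = a * x + b
assert y.shape == (h, w, c)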

Here is my attempt in Keras so far:

# TODO: learn gamma and beta parameters
x = tf.keras.layers.Conv2D(filters=num_filters_e1,
    kernel_size=5,
    strides=2,
    padding='same')(input)
x = tf.keras.layers.Multiply()([x, gamma]) # scale by gamma along channel dim
x = tf.keras.layers.Add()([x, beta]) # shift with beta along channel dim
y = tf.keras.layers.ReLU()(x) # apply activation after transformation

I'm not sure about how to obtain gamma and beta. These are supposed to be parameters learned by the model during training, but I'm not sure how to construct or specify them. Typically I just specify layers (either convolutional or dense) to learn weights, but I'm not sure what layer to use here and what the layer should take in as input. Do I have to somehow initialize a vector of ones and then learn the weights to transform those into gamma and beta?

Even if it is possible to do this with TensorFlow's batchnorm layer (which is still useful to know), I would like to learn how to implement this scaling/shifting from scratch. Thank you!


Solution

  • As mentioned in the comments, you can accomplish this with a custom Keras layer (see the tf.keras guide on writing custom layers). Subclass the Layer base class and implement the transform in its build and call methods. In the example below, the layer holds learnable weights gamma and beta, both of shape (num_channels,), initialized to ones and zeros, respectively.

    import tensorflow as tf
    
    class LinearTransform(tf.keras.layers.Layer):
        """Layer that implements y=m*x+b, where m and b are
        learnable parameters.
        """
        def __init__(
            self,
            gamma_initializer="ones",
            beta_initializer="zeros",
            dtype=None,
            **kwargs
        ):
            super().__init__(dtype=dtype, **kwargs)
            self.gamma_initializer = gamma_initializer
            self.beta_initializer = beta_initializer
    
        def build(self, input_shape):
            # One learnable scale and shift parameter per channel (last axis).
            num_channels = int(input_shape[-1])
            self.gamma = self.add_weight(
                name="gamma",
                shape=[num_channels],
                initializer=self.gamma_initializer,
                dtype=self.dtype,
            )
            self.beta = self.add_weight(
                name="beta",
                shape=[num_channels],
                initializer=self.beta_initializer,
                dtype=self.dtype,
            )
    
        def call(self, inputs):
            # gamma and beta broadcast along the last (channel) axis.
            return self.gamma * inputs + self.beta
    

    And here are tests of the behavior:

    import numpy as np

    tf.random.set_seed(42)
    inputs = tf.random.normal([1, 24, 24, 4], dtype="float32", seed=42)
    
    layer = LinearTransform()
    np.testing.assert_allclose(layer(inputs), inputs)
    
    layer = LinearTransform(
        gamma_initializer=tf.keras.initializers.constant(4),
        beta_initializer=tf.keras.initializers.constant(1),
    )
    np.testing.assert_allclose(layer(inputs), inputs * 4 + 1)
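
    Finally, here is how the layer could slot into the model from the question (a sketch; num_filters_e1 and input are assumed to be defined as in the original snippet):

    x = tf.keras.layers.Conv2D(filters=num_filters_e1,
        kernel_size=5,
        strides=2,
        padding='same')(input)
    x = LinearTransform()(x)  # gamma and beta are learned during training
    y = tf.keras.layers.ReLU()(x)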