Tags: tensorflow, conv-neural-network, image-segmentation, autoencoder, unet-neural-network

Image Segmentation TensorFlow tutorial


In this TensorFlow tutorial, the U-Net model is divided into two parts: the first is the contraction (encoder) path, where they use a pre-trained MobileNet that is not trainable. In the second part I cannot work out which layers are actually being trained. As far as I can see, only the last Conv2DTranspose layer seems trainable. Am I right?

And if I am, how can a single layer perform a task as complex as segmentation?

Tutorial link: https://www.tensorflow.org/tutorials/images/segmentation


Solution

  • The code for the image segmentation model from the tutorial is shown below:

    def unet_model(output_channels):
      inputs = tf.keras.layers.Input(shape=[128, 128, 3])
      x = inputs
    
      # Downsampling through the model
      skips = down_stack(x)
      x = skips[-1]
      skips = reversed(skips[:-1])
    
      # Upsampling and establishing the skip connections
      for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])
    
      # This is the last layer of the model
      last = tf.keras.layers.Conv2DTranspose(
          output_channels, 3, strides=2,
          padding='same')  #64x64 -> 128x128
    
      x = last(x)
    
      return tf.keras.Model(inputs=inputs, outputs=x)
    
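    For context, the tutorial then builds and compiles this model with three output channels (one logit per pixel class in the Oxford-IIIT Pet masks), roughly as follows:

    OUTPUT_CLASSES = 3

    model = unet_model(output_channels=OUTPUT_CLASSES)
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])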

    The first (downsampling) part of the model does not use the entire MobileNetV2 architecture, but only the layers

    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project'       # 4x4
    

    of the pre-trained MobileNetV2 model; these layers are frozen, i.e. non-trainable.
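
    For reference, the tutorial builds this frozen down-stack by instantiating a MobileNetV2 from tf.keras.applications, tapping the activations of those five layers, and then setting trainable = False on the resulting feature-extraction model:

    base_model = tf.keras.applications.MobileNetV2(
        input_shape=[128, 128, 3], include_top=False)

    # Use the activations of these intermediate layers as the skip connections
    layer_names = [
        'block_1_expand_relu',   # 64x64
        'block_3_expand_relu',   # 32x32
        'block_6_expand_relu',   # 16x16
        'block_13_expand_relu',  # 8x8
        'block_16_project',      # 4x4
    ]
    base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

    # Freeze the whole encoder in one step
    down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
    down_stack.trainable = False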

    The second part of the model (the part you are asking about), everything before the final Conv2DTranspose layer, is the upsampling path, defined by the list:

    up_stack = [
        pix2pix.upsample(512, 3),  # 4x4 -> 8x8
        pix2pix.upsample(256, 3),  # 8x8 -> 16x16
        pix2pix.upsample(128, 3),  # 16x16 -> 32x32
        pix2pix.upsample(64, 3),   # 32x32 -> 64x64
    ]
    

    This means the list calls a function named upsample from the module pix2pix, which the tutorial imports with from tensorflow_examples.models.pix2pix import pix2pix; the module's source code lives in the tensorflow/examples GitHub repository.

    The code for the upsample function is shown below:

    def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
      """Upsamples an input.
      Conv2DTranspose => Batchnorm => Dropout => Relu
      Args:
        filters: number of filters
        size: filter size
        norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
        apply_dropout: If True, adds the dropout layer
      Returns:
        Upsample Sequential Model
      """
    
      initializer = tf.random_normal_initializer(0., 0.02)
    
      result = tf.keras.Sequential()
      result.add(
          tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                          padding='same',
                                          kernel_initializer=initializer,
                                          use_bias=False))
    
      if norm_type.lower() == 'batchnorm':
        result.add(tf.keras.layers.BatchNormalization())
      elif norm_type.lower() == 'instancenorm':
        result.add(InstanceNormalization())  # defined in the same pix2pix module
    
      if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    
      result.add(tf.keras.layers.ReLU())
    
      return result
    
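    To see what a single upsample block does, you can run it on a dummy tensor (a sketch; the 4x4, 320-channel input is an assumed stand-in for the bottleneck that block_16_project produces for a 128x128 image):

    import tensorflow as tf
    from tensorflow_examples.models.pix2pix import pix2pix

    block = pix2pix.upsample(512, 3)   # Conv2DTranspose -> BatchNorm -> ReLU
    x = tf.zeros([1, 4, 4, 320])       # dummy bottleneck feature map
    print(block(x).shape)              # (1, 8, 8, 512): spatial size doubled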

    This means that the second part of the model comprises the upsample blocks defined above, with 512, 256, 128 and 64 filters respectively. Each block's Conv2DTranspose and BatchNormalization layers have trainable weights, and the final Conv2DTranspose is trainable as well. So no, it is not just one layer: the entire decoder (the upsampling path) is trained, and only the MobileNetV2 encoder is frozen.
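
    A quick sanity check (assuming the tutorial code above has been run) makes the split visible: the frozen down_stack reports no trainable weights, while every upsample block and the final Conv2DTranspose do:

    model = unet_model(output_channels=3)

    # The down_stack sub-model prints trainable=False with 0 trainable weights;
    # each upsample Sequential and the last Conv2DTranspose print trainable=True.
    for layer in model.layers:
        print(f'{layer.name:25s} trainable={layer.trainable} '
              f'trainable_weights={len(layer.trainable_weights)}')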