
How does transfer learning on EfficientNets work for grayscale images?


My question is more about how the algorithm works. I have successfully implemented EfficientNet transfer learning and model building for grayscale images, and now I want to understand why it works.

The most important aspect here is the grayscale input and its single channel. When I set channels=1, the model doesn't work because, if I understood correctly, the pretrained weights were computed on 3-channel images. When I set channels=3 it works perfectly.

So my question is: when I set channels=3 but feed the model preprocessed images with channels=1, why does it continue to work?

Code for EfficientNetB5

import tensorflow as tf
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.models import Sequential

# Variable assignments
num_classes = 9
img_height = 84
img_width = 112
channels = 3
batch_size = 32

# Make the input layer
new_input = Input(shape=(img_height, img_width, channels),
                  name='image_input')

# Download and use EfficientNetB5
tmp = tf.keras.applications.EfficientNetB5(include_top=False,
                                           weights='imagenet',
                                           input_tensor=new_input,
                                           pooling='max')
model = Sequential()
model.add(tmp)  # adding EfficientNetB5
model.add(Flatten())
...

Code for the grayscale preprocessing

from tensorflow.keras.preprocessing.image import ImageDataGenerator

data_generator = ImageDataGenerator(
        validation_split=0.2)

train_generator = data_generator.flow_from_directory(
        train_path,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        color_mode="grayscale",  # <-- the key line: images are loaded with one channel
        class_mode="categorical",
        subset="training")
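
For completeness, this is roughly how the two pieces are wired together. The Dense head and the fit settings below are illustrative stand-ins for the elided part of my code, not the original:

from tensorflow.keras.layers import Dense

model.add(Dense(num_classes, activation='softmax'))  # hypothetical classification head

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# TensorFlow warns that the model was built for (..., 3) inputs but is being fed
# (..., 1) batches, yet training runs fine -- which is exactly my question.
model.fit(train_generator, epochs=10)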

Solution

  • I dug into what happens when you feed grayscale images to EfficientNet models with three-channel inputs. Here are the first layers of EfficientNetB5 with input_shape=(128, 128, 3):

     Layer (type)                      Output Shape           Param #   Connected to
     ===================================================================================
     input_7 (InputLayer)              [(None, 128, 128, 3)]  0         []
     rescaling_7 (Rescaling)           (None, 128, 128, 3)    0         ['input_7[0][0]']
     normalization_13 (Normalization)  (None, 128, 128, 3)    7         ['rescaling_7[0][0]']
     tf.math.truediv_4 (TFOpLambda)    (None, 128, 128, 3)    0         ['normalization_13[0][0]']
     stem_conv_pad (ZeroPadding2D)     (None, 129, 129, 3)    0         ['tf.math.truediv_4[0][0]']
    

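    A quick way to reproduce this summary (the _7/_13 suffixes in the layer names depend on how many models were already created in the session):

    import tensorflow as tf

    probe = tf.keras.applications.EfficientNetB5(include_top=False,
                                                 weights='imagenet',
                                                 input_shape=(128, 128, 3))
    probe.summary()  # the rows above are the first layers of the printed table
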
    And here are the output shapes of these layers when the model is given a grayscale image:

    input_7 (128, 128, 1)
    rescaling_7 (128, 128, 1)
    normalization_13 (128, 128, 3)
    tf.math.truediv_4 (128, 128, 3)
    stem_conv_pad (129, 129, 3)
    

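    These shapes can be checked by probing the model layer by layer. A sketch, assuming the first five layers are the ones listed in the summary above; TensorFlow logs a shape warning for the one-channel input but still runs:

    import numpy as np
    import tensorflow as tf

    probe = tf.keras.applications.EfficientNetB5(include_top=False,
                                                 weights='imagenet',
                                                 input_shape=(128, 128, 3))
    gray_batch = np.random.rand(1, 128, 128, 1).astype('float32')  # one grayscale image

    for layer in probe.layers[1:5]:  # rescaling, normalization, truediv, stem_conv_pad
        sub_model = tf.keras.Model(inputs=probe.input, outputs=layer.output)
        print(layer.name, sub_model(gray_batch).shape)
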
    As you can see, the number of channels of the output tensor switches from 1 to 3 at the normalization_13 layer, so let's see what this layer is actually doing. The Normalization layer performs this operation on the input tensor:

    (input_tensor - self.mean) / sqrt(self.var)  # see https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
    

    The number of channels changes after the subtraction. As a matter of fact, self.mean looks like this:

    <tf.Tensor: shape=(1, 1, 1, 3), dtype=float32, numpy=array([[[[0.485, 0.456, 0.406]]]], dtype=float32)>
    

    So self.mean has three channels, and when you subtract a three-channel tensor from a one-channel tensor, broadcasting kicks in: the single channel is stretched along the last axis, and the output looks like this: [firstTensor - secondTensorFirstChannel, firstTensor - secondTensorSecondChannel, firstTensor - secondTensorThirdChannel].

    And this is how the magic happens, and this is why the model can take grayscale images as input! I have checked this with EfficientNetB5 and EfficientNetV2B2; even though they differ in how the Normalization layer is declared, the process is the same. I suppose that is also the case for the other EfficientNet models.
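
    This behavior is ordinary broadcasting and is easy to reproduce in isolation. A minimal NumPy sketch of the subtraction, using random data and the mean values printed above:

    import numpy as np

    gray = np.random.rand(1, 128, 128, 1).astype('float32')   # one-channel image batch
    mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 1, 3)

    diff = gray - mean   # the size-1 channel axis is broadcast against the 3 means
    print(diff.shape)                                        # (1, 128, 128, 3)
    print(np.allclose(diff[..., 1], gray[..., 0] - 0.456))   # True: channel 1 = gray - mean[1]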

    I hope it was clear enough!