python, pytorch

Watching a YouTube video on training a CNN classifier, can't understand this line of code related to normalisation


The code can be found here: https://github.com/gaurav67890/Pytorch_Tutorials/blob/master/cnn-scratch-training.ipynb

In code cell number 64

from torchvision import transforms

# Transforms
transformer = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # PIL image or numpy array (0-255) to float tensor in [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5],  # per-channel mean
                         [0.5, 0.5, 0.5])  # per-channel std; (x - mean) / std maps [0, 1] to [-1, 1]
])

Why exactly do we need to normalize the image tensor? Also, why do we pass the parameters [0.5,0.5,0.5],[0.5,0.5,0.5]?


Solution

  • Normalization does two things: it limits the data range and centers the values around zero, which usually improves the learning of an ML algorithm. Limiting the range is not strictly needed in this example, since transforms.ToTensor() already scales the pixel values to [0, 1].
    

    So the answer to your first question is: we normalize to center the data around zero and bring it into the [-1, 1] range.

    The parameters [0.5,0.5,0.5],[0.5,0.5,0.5] are the mean and the standard deviation, in that order, with one value for each of the three color channels. With both set to 0.5, the formula (x - 0.5) / 0.5 maps 0 to -1 and 1 to 1. You could also calculate them from the training images, but it appears this step was simplified here.
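
To make the arithmetic concrete, here is a minimal sketch using plain torch rather than the tutorial's actual data. The random `images` batch and the `normalize` helper are hypothetical stand-ins: the helper applies the same (x - mean) / std formula quoted in the question, per channel, and is not the torchvision implementation itself.

```python
import torch

# Assumption: random values stand in for a batch of images after ToTensor(),
# shape (N, C, H, W) with floats in [0, 1].
torch.manual_seed(0)
images = torch.rand(8, 3, 150, 150)

def normalize(batch, mean, std):
    """Hypothetical helper: apply (x - mean) / std separately to each channel."""
    mean = torch.tensor(mean).view(1, -1, 1, 1)
    std = torch.tensor(std).view(1, -1, 1, 1)
    return (batch - mean) / std

# With mean = std = 0.5, the value 0 maps to -1 and the value 1 maps to +1,
# so data in [0, 1] ends up in [-1, 1].
scaled = normalize(images, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
print(scaled.min().item(), scaled.max().item())

# Alternative: estimate the per-channel statistics from the data itself,
# which gives each channel roughly zero mean and unit standard deviation.
channel_mean = images.mean(dim=(0, 2, 3))
channel_std = images.std(dim=(0, 2, 3))
standardized = normalize(images, channel_mean.tolist(), channel_std.tolist())
print(standardized.mean(dim=(0, 2, 3)), standardized.std(dim=(0, 2, 3)))
```

In a real project the statistics would be computed once over the training set and then passed to transforms.Normalize; the fixed 0.5 values are a common shortcut when only the [-1, 1] range matters.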