
How does torchvision.transforms.Normalize operate?


I don't understand how normalization in PyTorch works.

I want to set the mean to 0 and the standard deviation to 1 across all columns in a tensor x of shape (2, 2, 3).

A simple example:

>>> import torch
>>> from torchvision import transforms
>>> x = torch.tensor([[[ 1.,  2.,  3.],
                       [ 4.,  5.,  6.]],

                       [[ 7.,  8.,  9.],
                        [10., 11., 12.]]])

>>> norm = transforms.Normalize((0, 0), (1, 1))
>>> norm(x)
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [10., 11., 12.]]])

So nothing has changed when applying the normalization transform. Why is that?


Solution

  • To answer your question: you've now realized that torchvision.transforms.Normalize doesn't work the way you anticipated. That's because it's not meant to:

    • normalize: make your data range in [0, 1], nor

    • standardize: make your data's mean=0 and std=1 (which is what you're looking for).

    The operation performed by T.Normalize (T being torchvision.transforms) is merely a shift-scale transform:

    output[channel] = (input[channel] - mean[channel]) / std[channel]
    

    The parameter names mean and std may seem misleading: they do not refer to the desired output statistics, but to arbitrary shift and scale values. Indeed, if you pass mean=0 and std=1, you get output = (input - 0) / 1 = input. Hence the result you observed: norm had no effect on your tensor values, while you were expecting a tensor with mean 0 and standard deviation 1.
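
    To convince yourself of this, here is a minimal sketch (assuming the imports and the tensor x from the question; manual is just a name introduced here for illustration) that reproduces what the transform does with plain tensor arithmetic, i.e. subtract mean and divide by std, broadcast along the channel (first) dimension:

    >>> mean, std = (0., 0.), (1., 1.)
    >>> # reshape to (C, 1, 1) so each value applies to one channel
    >>> manual = (x - torch.tensor(mean).view(-1, 1, 1)) / torch.tensor(std).view(-1, 1, 1)
    >>> torch.equal(manual, transforms.Normalize(mean, std)(x))
    True
    >>> torch.equal(manual, x)  # with mean=0 and std=1 the transform is the identity
    True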

    However, if you provide the correct mean and std parameters, i.e. mean=mean(data) and std=std(data), then you end up computing the z-score of your data channel by channel, which is what is usually called 'standardization'. So, in order to actually get mean=0 and std=1, you first need to compute the mean and standard deviation of your data.

    If you do:

    >>> mean, std = x.mean(), x.std()
    >>> mean, std
    (tensor(6.5000), tensor(3.6056))
    

    This gives you the global mean and the global standard deviation, respectively.
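
    As a side note (a minimal sketch, continuing the same session; global_norm is a name introduced here for illustration), feeding these global values back into the transform would standardize the tensor as a whole rather than per channel, which yields (up to print rounding):

    >>> global_norm = transforms.Normalize([mean.item()], [std.item()])
    >>> global_norm(x)
    tensor([[[-1.5254, -1.2481, -0.9707],
             [-0.6934, -0.4160, -0.1387]],

            [[ 0.1387,  0.4160,  0.6934],
             [ 0.9707,  1.2481,  1.5254]]])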

    Instead, what you want is to measure the 1st and 2nd order statistics per channel. Since T.Normalize treats the first dimension of a (C, H, W) tensor as the channel dimension, we need to apply torch.mean and torch.std over all dimensions except dim=0, i.e. over dims (1, 2). Both of those functions can receive a tuple of dimensions:

    >>> mean, std = x.mean((1, 2)), x.std((1, 2))
    >>> mean, std
    (tensor([3.5000, 9.5000]), tensor([1.8708, 1.8708]))
    

    The above are the correct per-channel mean and standard deviation of x. From there you can go ahead and build T.Normalize(mean, std) and apply it to x with the proper shift-scale parameters.

    >>> norm = transforms.Normalize(mean, std)
    >>> norm(x)
    tensor([[[-1.3363, -0.8018, -0.2673],
             [ 0.2673,  0.8018,  1.3363]],

            [[-1.3363, -0.8018, -0.2673],
             [ 0.2673,  0.8018,  1.3363]]])
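
    As a quick check (a minimal sketch, continuing the same session; y is just a throwaway name for the result), the transformed tensor now has zero mean and unit standard deviation in each channel:

    >>> y = norm(x)
    >>> torch.allclose(y.mean((1, 2)), torch.zeros(2), atol=1e-6)
    True
    >>> torch.allclose(y.std((1, 2)), torch.ones(2), atol=1e-4)
    True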