Tags: python, python-3.x, pytorch, torch, batch-normalization

PyTorch BatchNorm2d: "RuntimeError: running_mean should contain 1 elements not 64"


None of the similar questions solved my problem, so please do not flag this as a duplicate.

PyTorch's BatchNorm2d expects an input in the format N C H W, where

N = Batchsize
C = Channels
H = Height
W = Width

as they indicate in the docs: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html

If we test this using a random Tensor we get an error:

import torch
        
n = 32 # N = Batch size
c = 1 # C = Channels
h = 64 # H = Height
w = 512 # W = Width
        
torch.nn.BatchNorm2d(h)(torch.rand(n,c,h,w))

RuntimeError: running_mean should contain 1 elements not 64

The following code "works", but the input is in "NHWC" format:

import torch

n = 32 # N = Batch size
c = 1 # C = Channels
h = 64 # H = Height
w = 512 # W = Width

x = torch.rand(n,h,w,c)

x = torch.nn.BatchNorm2d(h)(x)

Solution

  • Changing the values of the N, C, H, or W variables does not change the layout PyTorch expects; those are just variable names. BatchNorm2d always treats dimension 1 of the input as the channel dimension. So if you pass an input of shape (n,h,w,c) as above, it is interpreted as N->N, H->C (H becomes the number of channels instead of the height you intended), W->H, and C->W.

    Returning to the question, the number of channels in your input data (dimension 1) should match the num_features argument you pass to nn.BatchNorm2d.
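
As a rough illustration of the rule above, the sketch below builds one layer whose channel count matches the input and one whose count does not. The names bn_ok, bn_bad, ok_stats, bad_stats and error are made up for this example; the shapes are the ones from the question.

```python
import torch

n, c, h, w = 32, 1, 64, 512
x = torch.rand(n, c, h, w)  # NCHW: dimension 1 holds the channels

bn_ok = torch.nn.BatchNorm2d(c)   # channel count matches x.shape[1]
bn_bad = torch.nn.BatchNorm2d(h)  # built for 64 channels, but x has 1

# The layer keeps one running mean entry per channel it was built for.
ok_stats = tuple(bn_ok.running_mean.shape)    # (1,)
bad_stats = tuple(bn_bad.running_mean.shape)  # (64,)

y = bn_ok(x)  # works; the output has the same shape as the input

# The mismatched layer is expected to fail with a RuntimeError.
try:
    bn_bad(x)
    error = None
except RuntimeError as exc:
    error = str(exc)
```

The size of running_mean is fixed when the layer is constructed, which is why the error message compares the 64 features of the layer with the single channel of the input.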

    In your case, the input has one channel, but BatchNorm2d(h) was created for 64 channels. To fix this, make the two match, as in these examples:

    Example:

    import torch
    n, c, h, w = 32, 64, 64, 512
    x = torch.rand(n,c,h,w)
    x = torch.nn.BatchNorm2d(c)(x)  # pass the channel count c (here c == h == 64)
    

    and

    import torch
    n, c, h, w = 32, 1, 64, 512
    x = torch.rand(n,c,h,w)
    x = torch.nn.BatchNorm2d(c)(x)
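
If your data really is stored as NHWC, one possible approach (a minimal sketch, not the only option) is to reorder the axes to NCHW before calling BatchNorm2d. The names x_nhwc, x_nchw, out and out_shape are hypothetical and used only for this example.

```python
import torch

n, c, h, w = 32, 1, 64, 512
x_nhwc = torch.rand(n, h, w, c)  # assumed channels-last input

# Move the last axis (channels) to position 1 to get N, C, H, W.
x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()

out = torch.nn.BatchNorm2d(c)(x_nchw)
out_shape = tuple(out.shape)  # (32, 1, 64, 512)
```

The call to contiguous() copies the data into the new layout; it is included here as a precaution and may not be required in every case.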
    

    I hope this helps you. Thanks!