Tags: deep-learning, pytorch, batch-normalization, mlp

How to use PyTorch nn.BatchNorm1d to get equal normalization across features?


I would like to ask a question regarding nn.BatchNorm1d in PyTorch.

I have one main tensor, which has shape [B, 3, N]. I also have two additional tensors with shapes [B, 3, V1] and [B, 3, V2]. I concatenate the main tensor with each of the two tensors separately, to construct new tensors of shape [B, 3, N+V1] and [B, 3, N+V2].

I pass my tensors to a plain MLP (consisting of Conv1d and BatchNorm1d layers). Ideally, I want the predictions to be "point-wise": regardless of the size of dimension 2, each point should get a consistent prediction that depends only on its own values. However, BatchNorm1d gives different results for the inputs [B, 3, N+V1] and [B, 3, N+V2], even though I only care about the first N points along dimension 2.

import torch
import torch.nn as nn

# nn.BatchNorm1d
B = 2
dim = 64
N = 40000
V1 = 1000
V2 = 2000
torch.manual_seed(0)
x = torch.rand(B, dim, N)     # main tensor: dim channels, N points
v1 = torch.rand(B, dim, V1)   # first set of extra points
v2 = torch.rand(B, dim, V2)   # second set of extra points

layer = nn.BatchNorm1d(dim)  # normalizes each channel over the batch and point dims

out2 = layer(torch.cat((x, v1), dim=2))
out3 = layer(torch.cat((x, v2), dim=2))

torch.equal(out2[:, :, :N], out3[:, :, :N])
# False
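For context, a minimal sketch of why this happens, assuming PyTorch's default BatchNorm1d settings (eps=1e-5, affine weight 1 and bias 0 at initialization) and smaller illustrative sizes than above: in training mode the layer computes one mean and one biased variance per channel, pooled over the batch and point dimensions, so appending V1 versus V2 extra points shifts the statistics applied to the first N points. In eval mode it uses the stored running statistics instead, which makes it point-wise.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, N, V1, V2 = 2, 8, 100, 10, 20  # small illustrative sizes
x = torch.rand(B, C, N)
a = torch.cat((x, torch.rand(B, C, V1)), dim=2)  # [B, C, N+V1]
b = torch.cat((x, torch.rand(B, C, V2)), dim=2)  # [B, C, N+V2]

bn = nn.BatchNorm1d(C)  # defaults: eps=1e-5, weight=1, bias=0

# Training mode: one mean/variance per channel, pooled over dims 0 (batch) and 2 (points).
bn.train()
with torch.no_grad():
    out_a = bn(a)
    out_b = bn(b)
mean = a.mean(dim=(0, 2), keepdim=True)
var = ((a - mean) ** 2).mean(dim=(0, 2), keepdim=True)  # biased variance
manual = (a - mean) / torch.sqrt(var + bn.eps)
train_match = torch.allclose(out_a, manual, atol=1e-4)
train_inconsistent = not torch.allclose(out_a[:, :, :N], out_b[:, :, :N])

# Eval mode: the stored running_mean/running_var are used, so each point is
# transformed independently of how many other points are in the tensor.
bn.eval()
with torch.no_grad():
    eval_consistent = torch.allclose(bn(a)[:, :, :N], bn(b)[:, :, :N])

print(train_match, train_inconsistent, eval_consistent)
```

Switching to eval mode only changes inference behavior; during training the batch statistics still depend on every point in the batch.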

Is there any way to get consistent predictions for the first N points?


Solution

  • Is this more along the lines of what you're looking for, i.e. normalizing each point across the channels only?

    out2 = torch.cat((x, v1), dim=2) / torch.linalg.norm(torch.cat((x, v1), dim=2),  dim=1, keepdim=True)
    out3 = torch.cat((x, v2), dim=2) / torch.linalg.norm(torch.cat((x, v2), dim=2),  dim=1, keepdim=True)
    
    torch.equal(out2[:, :, :N], out3[:, :, :N])
    # True
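The same per-point L2 normalization is also available as torch.nn.functional.normalize, which divides by the norm along the given dim (clamping the norm at a small eps to avoid division by zero). A minimal sketch with smaller, illustrative sizes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, N, V1, V2 = 2, 8, 100, 10, 20  # small illustrative sizes
x = torch.rand(B, C, N)
a = torch.cat((x, torch.rand(B, C, V1)), dim=2)
b = torch.cat((x, torch.rand(B, C, V2)), dim=2)

# L2-normalize every point (column) across the channel dimension.
out_a = F.normalize(a, p=2.0, dim=1)
out_b = F.normalize(b, p=2.0, dim=1)

# Same result as dividing by torch.linalg.norm along dim=1.
manual = a / torch.linalg.norm(a, dim=1, keepdim=True)
same_as_manual = torch.allclose(out_a, manual)
first_n_consistent = torch.allclose(out_a[:, :, :N], out_b[:, :, :N])
print(same_as_manual, first_n_consistent)
```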
    

    I think if you want to do something like this with the PyTorch nn modules, you'll need to transpose your channel and point dimensions so that you can use nn.LayerNorm or nn.InstanceNorm1d. See here for a nice visual comparison of the different normalization techniques.
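Following the LayerNorm suggestion, one option (a sketch with illustrative sizes, not the only way) is a single nn.LayerNorm over the channel dimension. Because nn.LayerNorm normalizes the trailing dimension and its affine parameters have one entry per channel rather than per point, the same layer handles both the N+V1 and the N+V2 inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, N, V1, V2 = 2, 8, 100, 10, 20  # small illustrative sizes
x = torch.rand(B, C, N)
a = torch.cat((x, torch.rand(B, C, V1)), dim=2)
b = torch.cat((x, torch.rand(B, C, V2)), dim=2)

# Move channels last: [B, C, L] -> [B, L, C]. Each point is then standardized
# over its own C values, and the affine parameters have shape [C].
ln = nn.LayerNorm(C)
out_a = ln(a.transpose(1, 2)).transpose(1, 2)  # back to [B, C, N+V1]
out_b = ln(b.transpose(1, 2)).transpose(1, 2)  # back to [B, C, N+V2]

first_n_consistent = torch.allclose(out_a[:, :, :N], out_b[:, :, :N])
print(out_a.shape, first_n_consistent)
```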

    Updated answer:

    In case you specifically want to use an nn module, InstanceNorm or GroupNorm can also give you this result. However, after transposing, the number of channels (N+V1 vs. N+V2) differs between the two inputs, so you'll need two distinct layers.

    layer1 = nn.GroupNorm(V1+N, V1+N)
    layer2 = nn.GroupNorm(V2+N, V2+N)
    out2 = layer1(torch.cat((x, v1), dim=2).transpose(1,2))
    out3 = layer2(torch.cat((x, v2), dim=2).transpose(1,2))
    
    torch.equal(out2[:, :N, :], out3[:, :N, :])
    # True
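To plug per-point normalization into a Conv1d MLP like the one in the question, a minimal sketch follows. ChannelLayerNorm and pointwise_mlp are hypothetical names, the layer sizes are illustrative, and it assumes the MLP uses kernel_size=1 convolutions (the question does not say so). Unlike the GroupNorm(L, L) approach, one module serves any number of points.

```python
import torch
import torch.nn as nn

class ChannelLayerNorm(nn.Module):
    """Hypothetical helper: LayerNorm over the channel dim of a [B, C, L] tensor."""

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [B, C, L] -> [B, L, C] -> normalize each point -> back to [B, C, L]
        return self.norm(x.transpose(1, 2)).transpose(1, 2)

def pointwise_mlp(in_ch: int, hidden: int, out_ch: int) -> nn.Sequential:
    # kernel_size=1 convolutions act on each point independently, so with a
    # per-point norm the whole network is point-wise.
    return nn.Sequential(
        nn.Conv1d(in_ch, hidden, kernel_size=1),
        ChannelLayerNorm(hidden),
        nn.ReLU(),
        nn.Conv1d(hidden, out_ch, kernel_size=1),
    )

torch.manual_seed(0)
B, N, V1, V2 = 2, 100, 10, 20  # small illustrative sizes
x = torch.rand(B, 3, N)
a = torch.cat((x, torch.rand(B, 3, V1)), dim=2)
b = torch.cat((x, torch.rand(B, 3, V2)), dim=2)

mlp = pointwise_mlp(3, 64, 16)
with torch.no_grad():
    pred_a = mlp(a)
    pred_b = mlp(b)

first_n_consistent = torch.allclose(pred_a[:, :, :N], pred_b[:, :, :N], atol=1e-5)
print(pred_a.shape, first_n_consistent)
```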