
Understanding torch.nn.LayerNorm in NLP


I'm trying to understand how torch.nn.LayerNorm works in an NLP model. Assuming the input data is a batch of sequences of word embeddings:

batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)
print("x: ", embedding)

layer_norm = torch.nn.LayerNorm(dim)
print("y: ", layer_norm(embedding))

# outputs:
"""
x:  tensor([[[ 0.5909,  0.1326,  0.8100,  0.7631],
         [ 0.5831, -1.7923, -0.1453, -0.6882],
         [ 1.1280,  1.6121, -1.2383,  0.2150]],

        [[-0.2128, -0.5246, -0.0511,  0.2798],
         [ 0.8254,  1.2262, -0.0252, -1.9972],
         [-0.6092, -0.4709, -0.8038, -1.2711]]])
y:  tensor([[[ 0.0626, -1.6495,  0.8810,  0.7060],
         [ 1.2621, -1.4789,  0.4216, -0.2048],
         [ 0.6437,  1.0897, -1.5360, -0.1973]],

        [[-0.2950, -1.3698,  0.2621,  1.4027],
         [ 0.6585,  0.9811, -0.0262, -1.6134],
         [ 0.5934,  1.0505, -0.0497, -1.5942]]],
       grad_fn=<NativeLayerNormBackward0>)
"""

From the documentation's description, my understanding is that the mean and std are computed over all embedding values per sample. So I tried to compute y[0, 0, :] manually:

mean = torch.mean(embedding[0, :, :])
std = torch.std(embedding[0, :, :])
print((embedding[0, 0, :] - mean) / std)

which gives tensor([ 0.4310, -0.0319, 0.6523, 0.6050]), and that's not the right output. What is the correct way to compute y[0, 0, :]?


Solution

  • The PyTorch LayerNorm documentation states that the mean and std are calculated over the last D dimensions. Based on that, one might expect that for an input of shape (batch_size, seq_size, embedding_dim) the statistics would be computed over (seq_size, embedding_dim), i.e. the last 2 dimensions excluding the batch dim.

    A similar question and answer with a layer norm implementation can be found here: layer Normalization in pytorch?.
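
    If you really did want the statistics over the last two dimensions, you can pass both sizes as normalized_shape. A minimal sketch (variable names are just for illustration):

    import torch

    batch_size, seq_size, dim = 2, 3, 4
    x = torch.randn(batch_size, seq_size, dim)

    # one mean/std per token, over the embedding dim (what the docs' NLP example uses)
    per_token = torch.nn.LayerNorm(dim)
    # one mean/std per sample, over the last two dims (seq_size, dim)
    per_sample = torch.nn.LayerNorm([seq_size, dim])

    print(per_token(x).shape, per_sample(x).shape)  # both torch.Size([2, 3, 4])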

    The papers quoted below show different applications of layer norm in NLP.

    Explanation of Instance vs Layer vs Group Norm


    From the Group Normalization paper:

    Layer Normalization (LN) operates along the channel dimension

    LN computes µ and σ along the (C, H, W) axes for each sample.


    Different Application Example

    In the PyTorch doc's NLP example with a 3D tensor, however, the mean and std are calculated over only the last dimension, embedding_dim.

    This paper describes the same usage as the PyTorch doc example:

    almost all NLP tasks take variable length sequences as input, which is very suitable for LN that only calculates statistics in the channel dimension without involving the batch and sequence length dimension.
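
    With LayerNorm(dim), each token vector is therefore normalized on its own, which is also how the y[0, 0, :] from the question can be reproduced. A minimal sketch:

    import torch

    ln = torch.nn.LayerNorm(4, elementwise_affine=False)
    x = torch.randn(2, 3, 4)
    y = ln(x)

    # y[0, 0] depends only on x[0, 0]: subtract its mean, divide by its (biased) std
    token = x[0, 0]
    manual = (token - token.mean()) / torch.sqrt(token.var(unbiased=False) + ln.eps)
    print(torch.allclose(y[0, 0], manual))  # True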


    Another paper shows a similar example:

    LN normalizes across the channel/feature dimension as shown in Figure 1.


    Manual Layer Norm with only Embed Dim

    import torch
    
    batch_size, seq_size, dim = 2, 3, 4
    last_dims = 4
    
    embedding = torch.randn(batch_size, seq_size, dim)
    print("x: ", embedding)
    
    layer_norm = torch.nn.LayerNorm(last_dims, elementwise_affine = False)
    layer_norm_out = layer_norm(embedding)
    print("y: ", layer_norm_out)
    
    eps: float = 0.00001
    # per-token mean and biased variance over the embedding dim only
    mean = torch.mean(embedding[0, :, :], dim=-1, keepdim=True)
    var = torch.square(embedding[0, :, :] - mean).mean(dim=-1, keepdim=True)
    y_custom = (embedding[0, :, :] - mean) / torch.sqrt(var + eps)
    print("y_custom: ", y_custom)
    assert torch.allclose(layer_norm_out[0], y_custom), 'Tensors do not match.'
    
    eps: float = 0.00001
    mean = torch.mean(embedding[1, :, :], dim=(-1), keepdim=True)
    var = torch.square(embedding[1, :, :] - mean).mean(dim=(-1), keepdim=True)
    y_custom = (embedding[1, :, :] - mean) / torch.sqrt(var + eps)
    print("y_custom: ", y_custom)
    assert torch.allclose(layer_norm_out[1], y_custom), 'Tensors do not match.'
    

    Output

    x:  tensor([[[-0.0594, -0.8702, -1.9837,  0.2914],
             [-0.4774,  1.0372,  0.6425, -1.1357],
             [ 0.3872, -0.9190, -0.5774,  0.3281]],
    
            [[-0.5548,  0.0815,  0.2333,  0.3569],
             [ 1.0380, -0.1756, -0.7417,  2.2930],
             [-0.0075, -0.3623,  1.9310, -0.7043]]])
    y:  tensor([[[ 0.6813, -0.2454, -1.5180,  1.0822],
             [-0.5700,  1.1774,  0.7220, -1.3295],
             [ 1.0285, -1.2779, -0.6747,  0.9241]],
    
            [[-1.6638,  0.1490,  0.5814,  0.9334],
             [ 0.3720, -0.6668, -1.1513,  1.4462],
             [-0.2171, -0.5644,  1.6809, -0.8994]]])
    y_custom:  tensor([[ 0.6813, -0.2454, -1.5180,  1.0822],
            [-0.5700,  1.1774,  0.7220, -1.3295],
            [ 1.0285, -1.2779, -0.6747,  0.9241]])
    y_custom:  tensor([[-1.6638,  0.1490,  0.5814,  0.9334],
            [ 0.3720, -0.6668, -1.1513,  1.4462],
            [-0.2171, -0.5644,  1.6809, -0.8994]])
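
    One more pitfall: the manual attempt in the question used torch.std, which defaults to the unbiased estimator (dividing by N - 1), while LayerNorm uses the biased variance (dividing by N). That is why the snippet above computes torch.square(x - mean).mean(...) instead; torch.var with unbiased=False gives the same result:

    eps: float = 0.00001
    mean = torch.mean(embedding[0, :, :], dim=-1, keepdim=True)
    # biased variance (divide by N), matching LayerNorm
    var = torch.var(embedding[0, :, :], dim=-1, unbiased=False, keepdim=True)
    assert torch.allclose(layer_norm_out[0], (embedding[0, :, :] - mean) / torch.sqrt(var + eps)), 'Tensors do not match.'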
    

    Manual Layer Norm over 4D Tensor

    import torch
    
    batch_size, c, h, w = 2, 3, 2, 4
    last_dims = [c, h, w]
    
    embedding = torch.randn(batch_size, c, h, w)
    print("x: ", embedding)
    
    layer_norm = torch.nn.LayerNorm(last_dims, elementwise_affine = False)
    layer_norm_out = layer_norm(embedding)
    print("y: ", layer_norm_out)
    
    
    eps: float = 0.00001
    # per-sample mean and biased variance over the (C, H, W) dims
    mean = torch.mean(embedding[0, :, :], dim=(-3, -2, -1), keepdim=True)
    var = torch.square(embedding[0, :, :] - mean).mean(dim=(-3, -2, -1), keepdim=True)
    y_custom = (embedding[0, :, :] - mean) / torch.sqrt(var + eps)
    print("y_custom: ", y_custom)
    assert torch.allclose(layer_norm_out[0], y_custom), 'Tensors do not match.'
    
    eps: float = 0.00001
    mean = torch.mean(embedding[1, :, :], dim=(-3, -2, -1), keepdim=True)
    var = torch.square(embedding[1, :, :] - mean).mean(dim=(-3, -2, -1), keepdim=True)
    y_custom = (embedding[1, :, :] - mean) / torch.sqrt(var + eps)
    print("y_custom: ", y_custom)
    assert torch.allclose(layer_norm_out[1], y_custom), 'Tensors do not match.'
    

    Output

    x:  tensor([[[[ 1.0902, -0.8648,  1.5785,  0.3087],
              [ 0.0249, -1.3477, -0.9565, -1.5024]],
    
             [[ 1.8024, -0.2894,  0.7284,  0.7822],
              [ 1.4385, -0.2848, -0.3114,  0.4633]],
    
             [[ 0.9061,  0.3066,  0.9916,  0.9284],
              [ 0.3356,  0.9162, -0.4579,  1.0669]]],
    
    
            [[[-0.8292,  0.9111, -0.7307, -1.1003],
              [ 0.3441, -1.9823,  0.1313,  0.2048]],
    
             [[-0.2838,  0.1147, -0.1605, -0.4637],
              [-2.1343, -0.4402,  1.6685,  0.4455]],
    
             [[ 0.6895, -2.7331,  1.1693, -0.6999],
              [-0.3497, -0.2942, -0.0028, -1.3541]]]])
    y:  tensor([[[[ 0.8653, -1.3279,  1.4131, -0.0114],
              [-0.3298, -1.8697, -1.4309, -2.0433]],
    
             [[ 1.6643, -0.6824,  0.4594,  0.5198],
              [ 1.2560, -0.6772, -0.7071,  0.1619]],
    
             [[ 0.6587, -0.0137,  0.7547,  0.6838],
              [ 0.0188,  0.6701, -0.8715,  0.8392]]],
    
    
            [[[-0.4938,  1.2220, -0.3967, -0.7610],
              [ 0.6629, -1.6306,  0.4531,  0.5256]],
    
             [[ 0.0439,  0.4368,  0.1655, -0.1335],
              [-1.7805, -0.1103,  1.9686,  0.7629]],
    
             [[ 1.0035, -2.3707,  1.4764, -0.3663],
              [-0.0211,  0.0337,  0.3210, -1.0112]]]])
    y_custom:  tensor([[[ 0.8653, -1.3279,  1.4131, -0.0114],
             [-0.3298, -1.8697, -1.4309, -2.0433]],
    
            [[ 1.6643, -0.6824,  0.4594,  0.5198],
             [ 1.2560, -0.6772, -0.7071,  0.1619]],
    
            [[ 0.6587, -0.0137,  0.7547,  0.6838],
             [ 0.0188,  0.6701, -0.8715,  0.8392]]])
    y_custom:  tensor([[[-0.4938,  1.2220, -0.3967, -0.7610],
             [ 0.6629, -1.6306,  0.4531,  0.5256]],
    
            [[ 0.0439,  0.4368,  0.1655, -0.1335],
             [-1.7805, -0.1103,  1.9686,  0.7629]],
    
            [[ 1.0035, -2.3707,  1.4764, -0.3663],
             [-0.0211,  0.0337,  0.3210, -1.0112]]])
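
    The same computation is also available as a functional call, torch.nn.functional.layer_norm, without constructing a module. For example, reusing the tensors from the 4D example above:

    import torch.nn.functional as F

    # normalized_shape = [c, h, w]; no weight/bias, matching elementwise_affine=False above
    y_functional = F.layer_norm(embedding, last_dims)
    assert torch.allclose(layer_norm_out, y_functional), 'Tensors do not match.'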
    

    Example of custom layer norm implementation

    from typing import Union, List
    
    import torch
    
    
    batch_size, seq_size, embed_dim = 2, 3, 4
    embedding = torch.randn(batch_size, seq_size, embed_dim)
    print("x: ", embedding)
    print(embedding.shape)
    print()
    
    
    layer_norm = torch.nn.LayerNorm(embed_dim, elementwise_affine=False)
    layer_norm_output = layer_norm(embedding)
    print("y: ", layer_norm_output)
    print(layer_norm_output.shape)
    print()
    
    
    def custom_layer_norm(
            x: torch.Tensor, dim: Union[int, List[int]] = -1, eps: float = 0.00001
    ) -> torch.Tensor:
        # dim may be a single int or a list of ints; torch.mean and Tensor.mean accept both
        mean = torch.mean(x, dim=dim, keepdim=True)
        var = torch.square(x - mean).mean(dim=dim, keepdim=True)
        return (x - mean) / torch.sqrt(var + eps)
    
    
    custom_layer_norm_output = custom_layer_norm(embedding)
    print("y_custom: ", custom_layer_norm_output)
    print(custom_layer_norm_output.shape)
    
    assert torch.allclose(layer_norm_output, custom_layer_norm_output), 'Tensors do not match.'
    

    Output

    x:  tensor([[[-0.4808, -0.1981,  0.4538, -1.2653],
             [ 0.3578,  0.6592,  0.2161,  0.3852],
             [ 1.2184, -0.4238, -0.3415, -0.3487]],
    
            [[ 0.9874, -1.7737,  0.1886,  0.0448],
             [-0.5162,  0.7872, -0.3433, -0.3266],
             [-0.5459, -0.0371,  1.2625, -1.6030]]])
    torch.Size([2, 3, 4])
    
    y:  tensor([[[-0.1755,  0.2829,  1.3397, -1.4471],
             [-0.2916,  1.5871, -1.1747, -0.1208],
             [ 1.7301, -0.6528, -0.5334, -0.5439]],
    
            [[ 1.1142, -1.6189,  0.3235,  0.1812],
             [-0.8048,  1.7141, -0.4709, -0.4384],
             [-0.3057,  0.1880,  1.4489, -1.3312]]])
    torch.Size([2, 3, 4])
    
    y_custom:  tensor([[[-0.1755,  0.2829,  1.3397, -1.4471],
             [-0.2916,  1.5871, -1.1747, -0.1208],
             [ 1.7301, -0.6528, -0.5334, -0.5439]],
    
            [[ 1.1142, -1.6189,  0.3235,  0.1812],
             [-0.8048,  1.7141, -0.4709, -0.4384],
             [-0.3057,  0.1880,  1.4489, -1.3312]]])
    torch.Size([2, 3, 4])
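
    If elementwise_affine is left at its default (True), LayerNorm additionally applies a learnable per-feature scale and shift, so a custom version needs the module's weight and bias to match. A sketch (at initialization weight is all ones and bias all zeros, so the outputs still coincide):

    layer_norm_affine = torch.nn.LayerNorm(embed_dim)  # elementwise_affine=True by default
    y_affine = layer_norm_affine(embedding)

    # apply the module's gamma (weight) and beta (bias) to the plain normalization
    y_custom_affine = custom_layer_norm(embedding) * layer_norm_affine.weight + layer_norm_affine.bias
    assert torch.allclose(y_affine, y_custom_affine, atol=1e-6), 'Tensors do not match.'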