I'm trying to understand how torch.nn.LayerNorm
works in an NLP model. Assume the input data is a batch of sequences of word embeddings:
batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)
print("x: ", embedding)
layer_norm = torch.nn.LayerNorm(dim)
print("y: ", layer_norm(embedding))
# outputs:
"""
x: tensor([[[ 0.5909, 0.1326, 0.8100, 0.7631],
[ 0.5831, -1.7923, -0.1453, -0.6882],
[ 1.1280, 1.6121, -1.2383, 0.2150]],
[[-0.2128, -0.5246, -0.0511, 0.2798],
[ 0.8254, 1.2262, -0.0252, -1.9972],
[-0.6092, -0.4709, -0.8038, -1.2711]]])
y: tensor([[[ 0.0626, -1.6495, 0.8810, 0.7060],
[ 1.2621, -1.4789, 0.4216, -0.2048],
[ 0.6437, 1.0897, -1.5360, -0.1973]],
[[-0.2950, -1.3698, 0.2621, 1.4027],
[ 0.6585, 0.9811, -0.0262, -1.6134],
[ 0.5934, 1.0505, -0.0497, -1.5942]]],
grad_fn=<NativeLayerNormBackward0>)
"""
From the documentation, my understanding is that the mean and std are computed over all embedding values per sample. So I tried to compute y[0, 0, :]
manually:
mean = torch.mean(embedding[0, :, :])
std = torch.std(embedding[0, :, :])
print((embedding[0, 0, :] - mean) / std)
which gives tensor([ 0.4310, -0.0319, 0.6523, 0.6050]),
and that's not the right output. What is the right way to compute y[0, 0, :]?
The PyTorch LayerNorm documentation states that the mean and std are calculated over the last D dimensions. Based on this, for an input of shape (batch_size, seq_size, embedding_dim)
I expected the calculation to be over (seq_size, embedding_dim),
that is, the last 2 dimensions excluding the batch dim.
A similar question and answer, with a layer norm implementation, can be found in "layer Normalization in pytorch?".
One paper, quoted below, describes a layer norm application that differs from the one used in NLP:
Layer Normalization (LN) operates along the channel dimension
LN computes µ and σ along the (C, H, W) axes for each sample.
In the PyTorch doc example for an NLP 3D tensor, however, the mean and std are calculated over only the last dim, embedding_dim, because that example passes embedding_dim as normalized_shape (so D = 1).
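Here is a minimal sketch of how normalized_shape selects the dimensions that are reduced, assuming the same (batch_size, seq_size, dim) shapes as in the question. The variable names (per_token, per_sample, and so on) are illustrative only. Passing a single int normalizes every token vector separately, while passing [seq_size, dim] normalizes over the last two dimensions of each sample.

```python
import torch

torch.manual_seed(0)
batch_size, seq_size, dim = 2, 3, 4
x = torch.randn(batch_size, seq_size, dim)

# normalized_shape = dim: statistics are taken over the last dimension only
per_token = torch.nn.LayerNorm(dim, elementwise_affine=False)(x)

# normalized_shape = [seq_size, dim]: statistics over the last two dimensions
per_sample = torch.nn.LayerNorm([seq_size, dim], elementwise_affine=False)(x)

# In the first case every token vector has (approximately) zero mean
token_means = per_token.mean(dim=-1)
# In the second case every sample has (approximately) zero mean
sample_means = per_sample.mean(dim=(-2, -1))

print(token_means.abs().max(), sample_means.abs().max())
```

So the expectation of "last 2 dimensions" only holds when normalized_shape itself has two entries.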
This paper describes the same behavior as the PyTorch doc example:
almost all NLP tasks take variable length sequences as input, which is very suitable for LN that only calculates statistics in the channel dimension without involving the batch and sequence length dimension.
Another paper gives a similar description:
LN normalizes across the channel/feature dimension as shown in Figure 1.
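The claim in the quotes above, that the statistics do not involve the sequence length, can be checked with a small experiment. This is only a sketch under the assumption that LayerNorm is built with normalized_shape equal to the embedding dim; long_seq and short_seq are made-up names.

```python
import torch

torch.manual_seed(0)
dim = 4
layer_norm = torch.nn.LayerNorm(dim, elementwise_affine=False)

long_seq = torch.randn(1, 5, dim)            # one sequence with 5 tokens
short_seq = long_seq[:, :3, :].clone()       # the same sequence cut to 3 tokens

out_long = layer_norm(long_seq)
out_short = layer_norm(short_seq)

# Each token is normalized with its own statistics, so the first 3 tokens
# should come out the same regardless of how long the sequence is
same = torch.allclose(out_long[:, :3, :], out_short, atol=1e-6)
print(same)
```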
# NLP example: normalize over the last dim (embedding dim) only
import torch
batch_size, seq_size, dim = 2, 3, 4
last_dims = 4  # normalized_shape: only the last dim is normalized
embedding = torch.randn(batch_size, seq_size, dim)
print("x: ", embedding)
# elementwise_affine=False disables the learnable weight and bias
layer_norm = torch.nn.LayerNorm(last_dims, elementwise_affine=False)
layer_norm_out = layer_norm(embedding)
print("y: ", layer_norm_out)
eps: float = 0.00001  # matches the LayerNorm default eps
# First sample: per-token mean and biased variance over the last dim
mean = torch.mean(embedding[0, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[0, :, :] - mean).mean(dim=(-1), keepdim=True)
y_custom = (embedding[0, :, :] - mean) / torch.sqrt(var + eps)
print("y_custom: ", y_custom)
assert torch.allclose(layer_norm_out[0], y_custom), 'Tensors do not match.'
# Second sample: same computation
mean = torch.mean(embedding[1, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[1, :, :] - mean).mean(dim=(-1), keepdim=True)
y_custom = (embedding[1, :, :] - mean) / torch.sqrt(var + eps)
print("y_custom: ", y_custom)
assert torch.allclose(layer_norm_out[1], y_custom), 'Tensors do not match.'
Output:
x: tensor([[[-0.0594, -0.8702, -1.9837, 0.2914],
[-0.4774, 1.0372, 0.6425, -1.1357],
[ 0.3872, -0.9190, -0.5774, 0.3281]],
[[-0.5548, 0.0815, 0.2333, 0.3569],
[ 1.0380, -0.1756, -0.7417, 2.2930],
[-0.0075, -0.3623, 1.9310, -0.7043]]])
y: tensor([[[ 0.6813, -0.2454, -1.5180, 1.0822],
[-0.5700, 1.1774, 0.7220, -1.3295],
[ 1.0285, -1.2779, -0.6747, 0.9241]],
[[-1.6638, 0.1490, 0.5814, 0.9334],
[ 0.3720, -0.6668, -1.1513, 1.4462],
[-0.2171, -0.5644, 1.6809, -0.8994]]])
y_custom: tensor([[ 0.6813, -0.2454, -1.5180, 1.0822],
[-0.5700, 1.1774, 0.7220, -1.3295],
[ 1.0285, -1.2779, -0.6747, 0.9241]])
y_custom: tensor([[-1.6638, 0.1490, 0.5814, 0.9334],
[ 0.3720, -0.6668, -1.1513, 1.4462],
[-0.2171, -0.5644, 1.6809, -0.8994]])
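One detail of the manual computation above deserves a note: the variance is the biased estimate (divide by N), with eps added under the square root, whereas torch.std defaults to the unbiased estimate (divide by N - 1). The sketch below compares the two on a single token vector. It assumes the default eps of 1e-5, and the names (with_biased, with_unbiased) are illustrative.

```python
import torch

torch.manual_seed(0)
dim = 4
token = torch.randn(dim)
eps = 1e-5

# Reference result from PyTorch for a single vector
reference = torch.nn.functional.layer_norm(token, (dim,), eps=eps)

mean = token.mean()
biased_var = ((token - mean) ** 2).mean()   # divides by N
unbiased_std = token.std()                  # divides by N - 1 by default

with_biased = (token - mean) / torch.sqrt(biased_var + eps)
with_unbiased = (token - mean) / unbiased_std

matches_biased = torch.allclose(reference, with_biased, atol=1e-6)
matches_unbiased = torch.allclose(reference, with_unbiased, atol=1e-6)
print(matches_biased, matches_unbiased)
```

Ignoring eps, the two versions differ by a factor of sqrt(N / (N - 1)), which is far from negligible for N = 4. So the attempt in the question would not match even if the mean and std were taken over the correct dimension.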
# Image example: normalize over the last three dims (C, H, W) of each sample
import torch
batch_size, c, h, w = 2, 3, 2, 4
last_dims = [c, h, w]  # normalized_shape covers channels, height and width
embedding = torch.randn(batch_size, c, h, w)
print("x: ", embedding)
# elementwise_affine=False disables the learnable weight and bias
layer_norm = torch.nn.LayerNorm(last_dims, elementwise_affine=False)
layer_norm_out = layer_norm(embedding)
print("y: ", layer_norm_out)
eps: float = 0.00001  # matches the LayerNorm default eps
# First sample: a single mean and biased variance over all of (C, H, W)
mean = torch.mean(embedding[0], dim=(-3, -2, -1), keepdim=True)
var = torch.square(embedding[0] - mean).mean(dim=(-3, -2, -1), keepdim=True)
y_custom = (embedding[0] - mean) / torch.sqrt(var + eps)
print("y_custom: ", y_custom)
assert torch.allclose(layer_norm_out[0], y_custom), 'Tensors do not match.'
# Second sample: same computation
mean = torch.mean(embedding[1], dim=(-3, -2, -1), keepdim=True)
var = torch.square(embedding[1] - mean).mean(dim=(-3, -2, -1), keepdim=True)
y_custom = (embedding[1] - mean) / torch.sqrt(var + eps)
print("y_custom: ", y_custom)
assert torch.allclose(layer_norm_out[1], y_custom), 'Tensors do not match.'
Output:
x: tensor([[[[ 1.0902, -0.8648, 1.5785, 0.3087],
[ 0.0249, -1.3477, -0.9565, -1.5024]],
[[ 1.8024, -0.2894, 0.7284, 0.7822],
[ 1.4385, -0.2848, -0.3114, 0.4633]],
[[ 0.9061, 0.3066, 0.9916, 0.9284],
[ 0.3356, 0.9162, -0.4579, 1.0669]]],
[[[-0.8292, 0.9111, -0.7307, -1.1003],
[ 0.3441, -1.9823, 0.1313, 0.2048]],
[[-0.2838, 0.1147, -0.1605, -0.4637],
[-2.1343, -0.4402, 1.6685, 0.4455]],
[[ 0.6895, -2.7331, 1.1693, -0.6999],
[-0.3497, -0.2942, -0.0028, -1.3541]]]])
y: tensor([[[[ 0.8653, -1.3279, 1.4131, -0.0114],
[-0.3298, -1.8697, -1.4309, -2.0433]],
[[ 1.6643, -0.6824, 0.4594, 0.5198],
[ 1.2560, -0.6772, -0.7071, 0.1619]],
[[ 0.6587, -0.0137, 0.7547, 0.6838],
[ 0.0188, 0.6701, -0.8715, 0.8392]]],
[[[-0.4938, 1.2220, -0.3967, -0.7610],
[ 0.6629, -1.6306, 0.4531, 0.5256]],
[[ 0.0439, 0.4368, 0.1655, -0.1335],
[-1.7805, -0.1103, 1.9686, 0.7629]],
[[ 1.0035, -2.3707, 1.4764, -0.3663],
[-0.0211, 0.0337, 0.3210, -1.0112]]]])
y_custom: tensor([[[ 0.8653, -1.3279, 1.4131, -0.0114],
[-0.3298, -1.8697, -1.4309, -2.0433]],
[[ 1.6643, -0.6824, 0.4594, 0.5198],
[ 1.2560, -0.6772, -0.7071, 0.1619]],
[[ 0.6587, -0.0137, 0.7547, 0.6838],
[ 0.0188, 0.6701, -0.8715, 0.8392]]])
y_custom: tensor([[[-0.4938, 1.2220, -0.3967, -0.7610],
[ 0.6629, -1.6306, 0.4531, 0.5256]],
[[ 0.0439, 0.4368, 0.1655, -0.1335],
[-1.7805, -0.1103, 1.9686, 0.7629]],
[[ 1.0035, -2.3707, 1.4764, -0.3663],
[-0.0211, 0.0337, 0.3210, -1.0112]]])
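The per-sample checks above can also be done for the whole batch in one step by reducing over dims (1, 2, 3) with keepdim=True. This is a sketch under the same (batch_size, c, h, w) assumption; y_batch and batch_matches are names introduced here for illustration.

```python
import torch

torch.manual_seed(0)
batch_size, c, h, w = 2, 3, 2, 4
x = torch.randn(batch_size, c, h, w)
eps = 1e-5

reference = torch.nn.LayerNorm([c, h, w], eps=eps, elementwise_affine=False)(x)

# Reduce over (C, H, W) for every sample at once. keepdim=True leaves the
# statistics with shape (batch_size, 1, 1, 1) so they broadcast against x.
mean = x.mean(dim=(1, 2, 3), keepdim=True)
var = ((x - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
y_batch = (x - mean) / torch.sqrt(var + eps)

batch_matches = torch.allclose(reference, y_batch, atol=1e-6)
print(batch_matches)
```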
from typing import Union, List
import torch
batch_size, seq_size, embed_dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, embed_dim)
print("x: ", embedding)
print(embedding.shape)
print()
layer_norm = torch.nn.LayerNorm(embed_dim, elementwise_affine=False)
layer_norm_output = layer_norm(embedding)
print("y: ", layer_norm_output)
print(layer_norm_output.shape)
print()
def custom_layer_norm(
    x: torch.Tensor, dim: Union[int, List[int]] = -1, eps: float = 0.00001
) -> torch.Tensor:
    # Pass dim through unchanged so both a single int and a list of ints work
    mean = torch.mean(x, dim=dim, keepdim=True)
    # Biased variance (divide by N), as used by LayerNorm
    var = torch.square(x - mean).mean(dim=dim, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)
custom_layer_norm_output = custom_layer_norm(embedding)
print("y_custom: ", custom_layer_norm_output)
print(custom_layer_norm_output.shape)
assert torch.allclose(layer_norm_output, custom_layer_norm_output), 'Tensors do not match.'
Output:
x: tensor([[[-0.4808, -0.1981, 0.4538, -1.2653],
[ 0.3578, 0.6592, 0.2161, 0.3852],
[ 1.2184, -0.4238, -0.3415, -0.3487]],
[[ 0.9874, -1.7737, 0.1886, 0.0448],
[-0.5162, 0.7872, -0.3433, -0.3266],
[-0.5459, -0.0371, 1.2625, -1.6030]]])
torch.Size([2, 3, 4])
y: tensor([[[-0.1755, 0.2829, 1.3397, -1.4471],
[-0.2916, 1.5871, -1.1747, -0.1208],
[ 1.7301, -0.6528, -0.5334, -0.5439]],
[[ 1.1142, -1.6189, 0.3235, 0.1812],
[-0.8048, 1.7141, -0.4709, -0.4384],
[-0.3057, 0.1880, 1.4489, -1.3312]]])
torch.Size([2, 3, 4])
y_custom: tensor([[[-0.1755, 0.2829, 1.3397, -1.4471],
[-0.2916, 1.5871, -1.1747, -0.1208],
[ 1.7301, -0.6528, -0.5334, -0.5439]],
[[ 1.1142, -1.6189, 0.3235, 0.1812],
[-0.8048, 1.7141, -0.4709, -0.4384],
[-0.3057, 0.1880, 1.4489, -1.3312]]])
torch.Size([2, 3, 4])
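All the examples above pass elementwise_affine=False. With the default (elementwise_affine=True), the normalized values are additionally multiplied by a learnable weight and shifted by a learnable bias, both of shape normalized_shape. A freshly constructed layer has weight 1 and bias 0, which is why the default layer still matches a plain normalization. The sketch below sets arbitrary, made-up weight and bias values just to make the effect visible.

```python
import torch

torch.manual_seed(0)
batch_size, seq_size, embed_dim = 2, 3, 4
x = torch.randn(batch_size, seq_size, embed_dim)

layer_norm = torch.nn.LayerNorm(embed_dim)  # elementwise_affine=True by default

# Arbitrary values, chosen only for illustration
with torch.no_grad():
    layer_norm.weight.copy_(torch.tensor([1.0, 2.0, 3.0, 4.0]))
    layer_norm.bias.copy_(torch.tensor([0.5, 0.0, -0.5, 1.0]))

mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
normalized = (x - mean) / torch.sqrt(var + layer_norm.eps)

# weight and bias have shape (embed_dim,) and broadcast over batch and sequence
y_manual = normalized * layer_norm.weight + layer_norm.bias

affine_matches = torch.allclose(layer_norm(x), y_manual, atol=1e-5)
print(affine_matches)
```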