Tags: deep-learning, pytorch, normalization, transformer-model, inference

PyTorch LayerNorm's mean and standard deviation are not fixed during inference


I’m working on recreating the input after torch.nn.LayerNorm. As far as I know, the mean and standard deviation for LayerNorm are fixed during the inference phase. Therefore, I thought I could extract these factors and recreate the original input from the LayerNorm output.

I have successfully extracted the weight and bias, which are not necessarily identical to the mean and standard deviation because LayerNorm has its own weight and bias parameters. My weight and bias parameters are fused from various factors, but they successfully recreate the original input from the LayerNorm output.
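
For reference, this is the inversion the extraction relies on (a sketch; it assumes the affine parameters γ and β are constant along the normalized dimension, e.g. a freshly initialized LayerNorm with γ = 1 and β = 0, so that a single scalar pair per row suffices):

\[
y = \gamma \cdot \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} + \beta
\;\;\Longrightarrow\;\;
x = y \cdot \frac{\sqrt{\mathrm{Var}[x] + \epsilon}}{\gamma} + \left(\mathrm{E}[x] - \beta \cdot \frac{\sqrt{\mathrm{Var}[x] + \epsilon}}{\gamma}\right)
\]

The coefficient of y plays the role of w and the remaining term the role of b. Since E[x] and Var[x] are per-row statistics, w and b are one scalar per row, which is what the two-point fit in the code below recovers.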

However, when I applied these extracted weight and bias parameters to another input tensor, expecting LayerNorm to behave the same way as with the previous input, I obtained a completely different output. I assume LayerNorm calculated new mean and standard deviation values for the second input, causing the difference. But I’m puzzled as to why LayerNorm computed the mean and standard deviation for the second input at all; they should have remained fixed during inference. Below is my code:

import torch
import torch.nn as nn

# Assumed setup (the original snippet omits it): a LayerNorm over the
# last (feature) dimension, matching the (1, 577, 768) inputs used below.
layer = nn.LayerNorm(768).eval()
input_data = torch.randn(1, 577, 768)

with torch.inference_mode():
    out = layer(input_data)

# Fit one scalar pair (w[i], b[i]) per sequence position from two feature
# columns (0 and 10), solving input = w * out + b.
seq_len = out.shape[1]
w = torch.zeros(seq_len)
b = torch.zeros(seq_len)

for i in range(seq_len):
    w[i] = (input_data[0, i, 0] - input_data[0, i, 10]) / (out[0, i, 0] - out[0, i, 10])
    b[i] = (input_data[0, i, 0] * out[0, i, 10] - input_data[0, i, 10] * out[0, i, 0]) / (out[0, i, 10] - out[0, i, 0])

# Reconstruct the original input from the LayerNorm output.
input_remade = torch.zeros_like(input_data)
for i1 in range(seq_len):
    input_remade[0, i1, :] = out[0, i1, :] * w[i1] + b[i1]
print(torch.sum(input_remade - input_data))


# Apply the w/b fitted on the first input to a fresh input.
input_data2 = torch.randn(1, 577, 768)
input_remade2 = torch.zeros_like(input_data2)
with torch.inference_mode():
    out2 = layer(input_data2)

for i1 in range(seq_len):
    input_remade2[0, i1, :] = out2[0, i1, :] * w[i1] + b[i1]
print(torch.sum(input_remade2 - input_data2))

# Refit w1/b1 on the second input and reconstruct again.
w1 = torch.zeros(seq_len)
b1 = torch.zeros(seq_len)

for i in range(seq_len):
    w1[i] = (input_data2[0, i, 0] - input_data2[0, i, 10]) / (out2[0, i, 0] - out2[0, i, 10])
    b1[i] = (input_data2[0, i, 0] * out2[0, i, 10] - input_data2[0, i, 10] * out2[0, i, 0]) / (out2[0, i, 10] - out2[0, i, 0])

for i1 in range(seq_len):
    input_remade2[0, i1, :] = out2[0, i1, :] * w1[i1] + b1[i1]
print(torch.sum(input_remade2 - input_data2))
The three printed sums, in order:

tensor(-0.0061)    # first input, fitted w/b: near-exact reconstruction
tensor(1280.9966)  # second input, first input's w/b: reconstruction fails
tensor(0.0014)     # second input, refitted w1/b1: near-exact again

Alternatively, is there any way to extract a fixed mean and standard deviation from a LayerNorm layer?


Solution

  • From the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

    This layer uses statistics computed from input data in both training and evaluation modes.

    The normalization applied is:

    \[
    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
    \]

    E[x] and Var[x] are recomputed for every input tensor, in both training and evaluation. Only the affine parameters γ and β are learned during training and stay fixed at inference.

    Thus what you are observing is the correct and expected behavior.
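
    A minimal sketch to verify this (the layer width 768 is assumed, to match the question's inputs): recompute E[x] and Var[x] by hand from each input and reproduce layer(x) exactly, confirming the statistics come from the input rather than from stored parameters.

    import torch
    import torch.nn as nn

    layer = nn.LayerNorm(768).eval()
    x = torch.randn(1, 577, 768)

    with torch.inference_mode():
        y = layer(x)

        # LayerNorm's statistics, recomputed from this particular input:
        mean = x.mean(dim=-1, keepdim=True)                 # E[x], per row
        var = x.var(dim=-1, unbiased=False, keepdim=True)   # Var[x] (biased), per row
        y_manual = (x - mean) / torch.sqrt(var + layer.eps) * layer.weight + layer.bias

        print(torch.allclose(y, y_manual, atol=1e-6))  # True: stats depend on x itself

    Feed it a different x and mean/var change with it; only layer.weight (γ) and layer.bias (β) stay fixed.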