Below is the code for combining the weights and biases of two linear layers into a single layer. I am not able to understand the line below: why do we have to multiply the bias by the transposed weight matrix? I thought it should just be the bias on its own, because we already multiply by the weight when computing the final outputs3.
combined_layer.bias.data = layer1.bias @ layer2.weight.t() + layer2.bias
# Assumes layer1 = nn.Linear(input_size, hidden_size) and
# layer2 = nn.Linear(hidden_size, output_size) are defined earlier
# Create a single layer to replace the two linear layers
combined_layer = nn.Linear(input_size, output_size)
combined_layer.weight.data = layer2.weight @ layer1.weight
combined_layer.bias.data = layer1.bias @ layer2.weight.t() + layer2.bias  # This should be just the bias
outputs3 = inputs @ combined_layer.weight.t() + combined_layer.bias
Could anyone please help me understand this?
You simply need to expand the original equation of the two Linear layers, i.e.
# out = layer2(layer1(x))
# with A = layer1.weight.t(), B = layer1.bias, C = layer2.weight.t(), D = layer2.bias,
# this is (x @ A + B) @ C + D
out = (x @ layer1.weight.t() + layer1.bias) @ layer2.weight.t() + layer2.bias
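Before expanding, a quick numeric sanity check (a minimal sketch; the sizes and seed are arbitrary, not from the original post) confirms this expression matches calling the two layers directly:

import torch
import torch.nn as nn

torch.manual_seed(0)
layer1 = nn.Linear(4, 8)   # arbitrary sizes for the check
layer2 = nn.Linear(8, 3)
x = torch.randn(5, 4)

# expanded form of layer2(layer1(x))
expanded = (x @ layer1.weight.t() + layer1.bias) @ layer2.weight.t() + layer2.bias
print(torch.allclose(layer2(layer1(x)), expanded, atol=1e-6))  # True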
You can expand this: (x @ A + B) @ C + D = (x @ A @ C) + B @ C + D
out = x @ layer1.weight.t() @ layer2.weight.t() + layer1.bias @ layer2.weight.t() + layer2.bias
out = x @ (layer1.weight.t() @ layer2.weight.t()) + (layer1.bias @ layer2.weight.t() + layer2.bias)
# the above equation is x @ (A @ C) + B @ C + D
# now you can assume
combined_layer.weight = layer2.weight @ layer1.weight
combined_layer.bias = layer1.bias @ layer2.weight.t() + layer2.bias
# final output
out = x @ combined_layer.weight.t() + combined_layer.bias
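Putting it all together, here is a self-contained check (a minimal sketch; the sizes, seed, and variable names are chosen arbitrarily for illustration) that the combined layer reproduces the two-layer output:

import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, output_size = 4, 8, 3
layer1 = nn.Linear(input_size, hidden_size)
layer2 = nn.Linear(hidden_size, output_size)
inputs = torch.randn(5, input_size)

# fold both layers into one, as in the question
combined_layer = nn.Linear(input_size, output_size)
combined_layer.weight.data = layer2.weight @ layer1.weight
combined_layer.bias.data = layer1.bias @ layer2.weight.t() + layer2.bias

outputs_two_layers = layer2(layer1(inputs))
outputs3 = inputs @ combined_layer.weight.t() + combined_layer.bias
print(torch.allclose(outputs_two_layers, outputs3, atol=1e-6))  # True

If you replace the combined bias with just layer2.bias, the check fails. That is exactly the point of confusion: in the original network, layer1.bias is itself fed through layer2's weights, so it must be folded into the combined bias as layer1.bias @ layer2.weight.t().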
Also, note that the matrix multiplication transpose rule is used here, i.e.
transpose(A @ B) = transpose(B) @ transpose(A)
That's why x is multiplied by combined_layer.weight.t(): since we stored layer2.weight @ layer1.weight without transposing, taking .t() at apply time gives back layer1.weight.t() @ layer2.weight.t(), which is exactly the factor that appears in the expansion above.
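The rule itself is easy to confirm numerically with random matrices (a throwaway sketch; the shapes are arbitrary):

import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)
# the transpose of a product reverses the order of the factors
print(torch.allclose((A @ B).t(), B.t() @ A.t()))  # True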