I'm trying to understand PyTorch's nn.Linear. I get that it applies a linear transformation to the input, but the docs specify the equation being used as y = xA^T + b.

This reminds me of y = (x[1] * w[1]) + (x[2] * w[2]) + b. Is that at all what's happening here?
Also, here is my current understanding of the variables in this equation. Is this correct?

x = input
A = weight (I think)
T = not sure
b = bias
The T stands for the transpose operation. For reasons that aren't entirely clear, PyTorch stores the transpose of the weight matrix for linear layers: layer.weight has shape (out_features, in_features), and the forward pass computes x @ weight.T. You can find some discussion of this here, but it seems to be a legacy thing.
layer = nn.Linear(64, 128)
layer.weight.shape
> torch.Size([128, 64]) # we would expect (64, 128) but we get the transpose (128, 64)
x = torch.randn(8, 64) # random input
# nn.Linear computes `xA^T + b`
((x @ layer.weight.T) + layer.bias == layer(x)).all()
> tensor(True)
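To connect this back to the y = (x[1] * w[1]) + (x[2] * w[2]) + b intuition from the question: yes, each output unit is exactly that kind of weighted sum. A minimal sketch (the layer sizes here are arbitrary, just for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small layer for illustration: 3 inputs -> 2 outputs
layer = nn.Linear(3, 2)
x = torch.randn(3)

# Output unit j is the dot product of the input with row j of the
# weight matrix, plus that unit's bias -- i.e. the
# y = x[0]*w[0] + x[1]*w[1] + ... + b form from the question.
manual = torch.stack([
    (x * layer.weight[j]).sum() + layer.bias[j]
    for j in range(2)
])

print(torch.allclose(manual, layer(x)))  # True (up to float rounding)
```

So xA^T + b is just that per-unit dot product done for all output units at once, which is why the weight rows (not columns) correspond to output units.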