
What is T in nn.Linear equation?


I'm trying to understand PyTorch's nn.Linear. I get that it applies a linear transformation to the input, but the docs give the equation being used as y = xA^T + b.

This reminds me of y = (x[1] * w[1]) + (x[2] * w[2]) + b. Is that at all what's happening here?

Also, here is my current understanding of the variables in this equation; is this correct?

  • x = input
  • A = weight (I think)
  • T = not sure
  • b = bias


Solution

  • The T stands for the transpose operation. For reasons that aren't entirely clear, PyTorch stores the transpose of the weight matrix for linear layers, so weight has shape (out_features, in_features). You can find some discussion of this here, but it seems to be a legacy thing.

    layer = nn.Linear(64, 128)
    layer.weight.shape
    > torch.Size([128, 64]) # we would expect (64, 128) but we get the transpose (128, 64)
    
    x = torch.randn(8, 64) # random input
    
    # nn.Linear computes `xA^T + b`
    ((x @ layer.weight.T) + layer.bias == layer(x)).all()
    > tensor(True)
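
  • As for the asker's formula: yes, y = (x[1] * w[1]) + (x[2] * w[2]) + b is exactly what each output unit computes. A minimal sketch for a layer with two inputs and one output (the sizes here are arbitrary, chosen just for illustration):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    layer = nn.Linear(2, 1)       # two input features, one output unit

    x = torch.tensor([[1.0, 2.0]])  # a single sample with two features

    # One output unit corresponds to one row of the weight matrix.
    w = layer.weight[0]             # shape (2,)
    b = layer.bias[0]

    # The asker's formula, written out term by term:
    manual = (x[0, 0] * w[0]) + (x[0, 1] * w[1]) + b

    torch.allclose(manual, layer(x)[0, 0])
    > True

    With more output units, nn.Linear just repeats this dot-product-plus-bias once per row of the weight matrix, which is what xA^T + b expresses in matrix form.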