
Output of the model depends on the shape of the weights tensor


I want to train a model to sum its three inputs, so the task is as simple as possible.

At first, the weights are initialized randomly. This gives a poor loss (approx. 0.5).

Then I initialize the weights with zeros. There are two options:

  1. the shape of the weights tensor is [1, 3]
  2. the shape of the weights tensor is [3]
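The two options differ in the shape of what the layer returns. Below is a minimal sketch, assuming a bias-free `Linear(3, 1)` and random input of shape (1024, 3) as in the question's code; `output_shape` is a hypothetical helper introduced only for illustration.

```python
import torch
from torch.nn import Linear

# Hypothetical helper: build a bias-free Linear(3, 1), replace its weight
# with a zero tensor of the given shape, and return the output shape.
def output_shape(weight_shape):
    layer = Linear(3, 1, bias=False)
    layer.weight = torch.nn.Parameter(torch.zeros(weight_shape))
    X = torch.rand((1024, 3))
    return tuple(layer(X).shape)

shape_opt1 = output_shape((1, 3))  # option 1: weight [1, 3] -> output (1024, 1)
shape_opt2 = output_shape((3,))    # option 2: weight [3]    -> output (1024,)
print(shape_opt1, shape_opt2)
```

With a 1-D weight, the linear layer drops the output-feature dimension, so only option 2 produces an output with the same shape as the target `y`.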

With the 1st option, the model still performs badly and can't learn this simple formula.

With the 2nd option, it works perfectly, reaching a loss of 10e-12.

Why does the result depend on the shape of the weights? Why do I need to initialize the model with zeros to solve this simple problem?

    import torch
    from torch.nn import Sequential as Seq, Linear as Lin
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    
    X = torch.rand((1024, 3))
    y = (X[:,0] + X[:,1] + X[:,2])
    m = Seq(Lin(3, 1, bias=False))
    
    # 1 option
    m[0].weight = torch.nn.parameter.Parameter(torch.tensor([[0, 0, 0]], dtype=torch.float))
    
    # 2 option
    #m[0].weight = torch.nn.parameter.Parameter(torch.tensor([0, 0, 0], dtype=torch.float))
    
    optim = torch.optim.SGD(m.parameters(), lr=10e-2)
    scheduler = ReduceLROnPlateau(optim, 'min', factor=0.5, patience=20, verbose=True)
    mse = torch.nn.MSELoss()
    for epoch in range(500):
        optim.zero_grad()
        out = m(X)
        loss = mse(out, y)
        loss.backward()
        optim.step()
        if epoch % 20 == 0:
            print(loss.item())
        scheduler.step(loss)


Solution

  • The first option doesn't learn because of broadcasting: out.shape == (1024, 1), while the corresponding target y has shape (1024,). MSELoss computes the mean of (out - y)^2, which in this case broadcasts to shape (1024, 1024), clearly the wrong objective for this task. With the 2nd option, (out - y)^2 has shape (1024,), and its mean is the actual MSE. The default approach, without explicitly changing the weight shape (options 1 and 2), would also work if the target is reshaped to (1024, 1), for example with y = y.unsqueeze(-1) right after y is defined.
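The broadcasting explanation can be checked directly. The following is a minimal sketch that reuses the question's `X`/`y` setup, with an added fixed seed and the learning-rate scheduler left out for brevity. The first part shows that `MSELoss` on shapes (1024, 1) and (1024,) equals the mean over the full (1024, 1024) difference matrix, so even perfect predictions get a large loss. The second part trains the unmodified `Lin(3, 1, bias=False)` model, with its default random initialization, against `y.unsqueeze(-1)`. The names `perfect` and `y2`, the seed, and the training settings are illustrative choices, not taken from the answer.

```python
import torch
from torch.nn import Sequential as Seq, Linear as Lin

torch.manual_seed(0)
X = torch.rand((1024, 3))
y = X.sum(dim=1)  # same target as in the question, shape (1024,)
mse = torch.nn.MSELoss()

# 1) Even a *perfect* prediction shaped (1024, 1) gets a large loss, because
#    (1024, 1) - (1024,) broadcasts to (1024, 1024): every prediction is
#    compared against every target.
perfect = y.unsqueeze(-1).clone()
diff = perfect - y
print(diff.shape)                   # torch.Size([1024, 1024])
broadcast_loss = mse(perfect, y)    # PyTorch also emits a size-mismatch UserWarning here
pairwise_mean = (diff ** 2).mean()  # the same quantity, computed explicitly
correct_loss = mse(perfect, y.unsqueeze(-1))
print(broadcast_loss.item(), pairwise_mean.item(), correct_loss.item())

# 2) With the target reshaped to (1024, 1), the default weight shape [1, 3]
#    and the default random initialization train without any changes.
y2 = y.unsqueeze(-1)
m = Seq(Lin(3, 1, bias=False))
optim = torch.optim.SGD(m.parameters(), lr=0.1)
for epoch in range(500):
    optim.zero_grad()
    loss = mse(m(X), y2)
    loss.backward()
    optim.step()

final_loss = loss.item()
learned_weights = m[0].weight.detach()
print(final_loss, learned_weights)  # weights should approach [[1, 1, 1]]
```

This also suggests that zero initialization is not what makes option 2 work: once the output and target shapes match, the default initialization should converge as well.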