
Weak optimizers in PyTorch


Consider a simple line fit a * x + b = x, where a and b are the parameters to optimize and x is the observed vector, given by:

import torch
X = torch.randn(1000,1,1)

One can immediately see that the exact solution is a = 1, b = 0 for any x, and it can be found as easily as:

import numpy as np
np.polyfit(X.numpy().flatten(), X.numpy().flatten(), 1)
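The same closed-form least-squares fit can also be written out explicitly with a design matrix whose columns are x and a column of ones. This is a minimal sketch using NumPy's np.linalg.lstsq; the fixed seed and the variable names are illustrative assumptions, not part of the question:

```python
import numpy as np

rng = np.random.default_rng(0)   # fixed seed, only for reproducibility
x = rng.standard_normal(1000)    # plays the role of X.numpy().flatten()

# Design matrix with columns [x, 1], so that A @ [a, b] == a * x + b
A = np.column_stack([x, np.ones_like(x)])

# Solve min ||A @ [a, b] - x||^2; lstsq returns (solution, residuals, rank, singular values)
(a, b), *_ = np.linalg.lstsq(A, x, rcond=None)
print(a, b)  # a is 1 and b is 0, up to floating-point rounding
```

Since the target is x itself, the residual vanishes exactly at a = 1, b = 0, which is why no iterative optimizer is needed here.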

I am now trying to find this solution by gradient descent in PyTorch, using the mean squared error as the optimization criterion.

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torch.optim import Adam, SGD, Adagrad, ASGD 

X = torch.randn(1000,1,1) # Sample data

class SimpleNet(nn.Module): # Trivial neural network containing two weights
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.f1 = nn.Linear(1,1)

    def forward(self, x):
        x = self.f1(x)
        return x

# Testing default setting of 3 basic optimizers

K = 500
net = SimpleNet() 
optimizer = Adam(params=net.parameters())
Adam_losses = []
optimizer.zero_grad()   # zero the gradient buffers
for k in range(K):
    for b in range(1): # single batch
        loss = torch.mean((net.forward(X[b,:,:]) - X[b,:, :])**2)
        loss.backward()
        optimizer.step()
        Adam_losses.append(float(loss.detach()))

net = SimpleNet()
optimizer = SGD(params=net.parameters(), lr=0.0001)
SGD_losses = []
optimizer.zero_grad()   # zero the gradient buffers
for k in range(K):
    for b in range(1): # single batch
        loss = torch.mean((net.forward(X[b,:,:]) - X[b,:, :])**2)
        loss.backward()
        optimizer.step()
        SGD_losses.append(float(loss.detach()))

net = SimpleNet()     
optimizer = Adagrad(params=net.parameters())
Adagrad_losses = []
optimizer.zero_grad()   # zero the gradient buffers
for k in range(K):
    for b in range(1): # single batch
        loss = torch.mean((net.forward(X[b,:,:]) - X[b,:, :])**2)
        loss.backward()
        optimizer.step()
        Adagrad_losses.append(float(loss.detach()))

The training progress, in terms of loss evolution, is shown in the following figure:

[Figure: Convergence process of 3 optimizer algorithms]

What surprises me is the very slow convergence of the algorithms with their default settings. I therefore have two questions:

1) Is it possible to achieve an arbitrarily small error (loss) purely by means of some PyTorch optimizer? Since the loss function is convex, it should definitely be possible; however, I cannot figure out how to achieve this with PyTorch. Note that the three optimizers above cannot do it; see the loss progress in log scale over 20000 iterations: [Figure: Log-scale plot for 20000 training iterations]

2) I am wondering how the optimizers can work well in complex examples when they do not work well even in this extremely simple example. Or (and that is really the second question) is there something wrong in how I applied them above that I missed?
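The convexity claim in question 1 can be checked directly. Writing the mean squared error over N samples x_i as a function of the two parameters a and b gives a quadratic whose gradient and Hessian are:

```latex
L(a, b) = \frac{1}{N}\sum_{i=1}^{N} \left(a x_i + b - x_i\right)^2

\nabla L(a, b) = \frac{2}{N}\sum_{i=1}^{N} \left(a x_i + b - x_i\right)
  \begin{pmatrix} x_i \\ 1 \end{pmatrix},
\qquad
\nabla^2 L = 2 \begin{pmatrix} \overline{x^2} & \bar{x} \\ \bar{x} & 1 \end{pmatrix}

\bar{x} = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
\overline{x^2} = \frac{1}{N}\sum_{i=1}^{N} x_i^2, \qquad
\det \nabla^2 L = 4\left(\overline{x^2} - \bar{x}^2\right) \ge 0
```

The Hessian is constant, has positive trace and a nonnegative determinant (four times the sample variance), so L is convex, and strictly convex unless all x_i are equal; its minimum L = 0 is at a = 1, b = 0. For such a quadratic, full-batch gradient descent with a fixed step size below 2 divided by the largest Hessian eigenvalue converges geometrically; for standard normal samples both eigenvalues are close to 2.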


Solution

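  • zero_grad is called in the wrong place. In the question it runs only once, before the loop, and PyTorch accumulates gradients in each parameter's .grad instead of overwriting them. So on every iteration the new gradient is added to the sum of all previous ones, and the optimizer steps along that sum. As the parameters get close to the solution, the accumulated old gradients keep throwing them off again, so the loss oscillates instead of converging.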

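    The code below performs the task easily. It calls zero_grad at the start of every iteration, and it computes the loss over all of X rather than over X[b,:,:], which for b = 0 is just the first sample: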

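    import torch
    import torch.nn as nn
    from torch.optim import Adam

    X = torch.randn(1000, 1, 1)
    EPOCHS = 10000  # illustrative upper bound; the loop breaks early once the loss is tiny

    class SimpleNet(nn.Module):  # same model as in the question: a * x + b
        def __init__(self):
            super().__init__()
            self.f1 = nn.Linear(1, 1)

        def forward(self, x):
            return self.f1(x)

    net = SimpleNet()
    optimizer = Adam(params=net.parameters())
    for epoch in range(EPOCHS):
        optimizer.zero_grad()  # clear the gradients from the previous iteration
        loss = torch.mean((net(X) - X) ** 2)  # MSE over all 1000 samples
        if loss < 1e-8:
            print(epoch, loss.item())
            break
        loss.backward()
        optimizer.step()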
    

    1) Is it possible to achieve an arbitrarily small error (loss) purely by means of some PyTorch optimizer?

    Yes. The precision above (loss < 1e-8) is reached in roughly 1500 epochs, and you can go lower, down to machine precision (float32 in this case).

    2) I am wondering how the optimizers can work well in complex examples when they do not work well even in this extremely simple example.

    Currently, we have nothing better (at least nothing in widespread use) for network optimization than first-order methods. They are used because computing the gradient is much faster than computing the Hessian that higher-order methods require. Moreover, complex, non-convex functions may have many minima that fulfill the task reasonably well, so there is no need for a global minimum per se (although under some conditions it may be needed; see this paper).
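The zero_grad explanation at the start of this answer can be checked in isolation. PyTorch adds the result of each backward() call into a parameter's .grad rather than overwriting it, so two backward() calls without clearing .grad in between give twice the gradient. A minimal sketch with a single scalar parameter w and a toy loss (both illustrative, not taken from the question):

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

def loss_fn():
    return (w - 1.0) ** 2  # d(loss)/dw = 2 * (w - 1) = 4 at w = 3

loss_fn().backward()
first = w.grad.item()    # 4.0

loss_fn().backward()     # no clearing in between: the new gradient is added to .grad
second = w.grad.item()   # 8.0

w.grad.zero_()           # clear the accumulated gradient (optimizer.zero_grad() does the equivalent for every parameter)
loss_fn().backward()
third = w.grad.item()    # 4.0 again

print(first, second, third)  # 4.0 8.0 4.0
```

In the question's loops the same thing happens on every iteration, so each optimizer step uses the running sum of all gradients seen so far.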
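The answer to question 1 also holds for plain gradient descent, not only for Adam. Below is a minimal sketch of full-batch gradient descent with torch.optim.SGD. Assumptions: the learning rate 0.1, the 300 steps and the seed are hand-picked for this problem, and nn.Linear(1, 1) stands in for the question's SimpleNet (it is the same model).

```python
import torch
import torch.nn as nn
from torch.optim import SGD

torch.manual_seed(0)            # fixed seed, only for reproducibility
X = torch.randn(1000, 1, 1)     # same shape as in the question
net = nn.Linear(1, 1)           # same model as SimpleNet: y = a * x + b

# Assumption: lr = 0.1 is hand-picked; it is well below the 2 / lambda_max
# stability bound, since both Hessian eigenvalues are about 2 for standard normal x.
optimizer = SGD(net.parameters(), lr=0.1)

losses = []
for step in range(300):
    optimizer.zero_grad()                    # clear gradients from the previous step
    loss = torch.mean((net(X) - X) ** 2)     # full-batch MSE over all 1000 samples
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

a, b = net.weight.item(), net.bias.item()
print(losses[-1], a, b)   # loss near float32 precision, a close to 1, b close to 0
```

With lr = 0.1 the parameter error shrinks by a factor of roughly 0.8 per step. By the same estimate, lr = 0.0001 shrinks it only by a factor of about 1 - 0.0002 per step on this full-batch loss, which alone makes convergence take tens of thousands of steps.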
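The Hessian argument in the answer to question 2 can be made concrete on this problem. With only two parameters the Hessian is 2 × 2, and because the loss is quadratic, a single Newton step (a second-order method) lands exactly on the minimizer. A minimal sketch using torch.autograd.functional; the starting point and seed are arbitrary, and float64 is used only to make the result clean:

```python
import torch
from torch.autograd.functional import hessian, jacobian

torch.manual_seed(0)
X = torch.randn(1000, 1, 1, dtype=torch.float64)  # float64 for a clean result

def mse(params):
    # params = [a, b]; the model a * x + b is fitted to the target x itself
    return torch.mean((params[0] * X + params[1] - X) ** 2)

params = torch.tensor([0.3, -0.7], dtype=torch.float64)  # arbitrary starting point
g = jacobian(mse, params)    # gradient of the scalar loss, shape (2,)
H = hessian(mse, params)     # Hessian, shape (2, 2)

# One Newton step: solve H @ delta = g and move against delta
params = params - torch.linalg.solve(H, g)
print(params, mse(params).item())  # params is approximately [1, 0]
```

For a network with n parameters the Hessian has n × n entries, which quickly becomes too expensive to form and solve; that cost is why large networks are trained with first-order methods.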