
Training results differ for classification using PyTorch APIs vs. fastai


I have two Python training scripts: one uses PyTorch's API directly for classification training and the other uses fastai. The fastai version gives much better results.

Training outcomes are as follows.

Fastai
epoch     train_loss  valid_loss  accuracy  time    
0         0.205338    2.318084    0.466482  23:02                         
1         0.182328    0.041315    0.993334  22:51                         
2         0.112462    0.064061    0.988932  22:47                         
3         0.052034    0.044727    0.986920  22:45                         
4         0.178388    0.081247    0.980883  22:45                         
5         0.009298    0.011817    0.996730  22:44                         
6         0.004008    0.003211    0.999748  22:43 

Using Pytorch
Epoch [1/10], train_loss : 31.0000 , val_loss : 1.6594, accuracy: 0.3568
Epoch [2/10], train_loss : 7.0000 , val_loss : 1.7065, accuracy: 0.3723
Epoch [3/10], train_loss : 4.0000 , val_loss : 1.6878, accuracy: 0.3889
Epoch [4/10], train_loss : 3.0000 , val_loss : 1.7054, accuracy: 0.4066
Epoch [5/10], train_loss : 2.0000 , val_loss : 1.7154, accuracy: 0.4106
Epoch [6/10], train_loss : 2.0000 , val_loss : 1.7232, accuracy: 0.4144
Epoch [7/10], train_loss : 2.0000 , val_loss : 1.7125, accuracy: 0.4295
Epoch [8/10], train_loss : 1.0000 , val_loss : 1.7372, accuracy: 0.4343
Epoch [9/10], train_loss : 1.0000 , val_loss : 1.6871, accuracy: 0.4441
Epoch [10/10], train_loss : 1.0000 , val_loss : 1.7384, accuracy: 0.4552

The PyTorch version does not converge. I used the same network (Wideresnet22) in both, and both are trained from scratch without a pretrained model.

The network is here.

Training using Pytorch is here.

The fastai version is as follows:

import torch.nn.functional as F
from fastai.basic_data import DataBunch
from fastai.train import Learner
from fastai.metrics import accuracy

# train_ds, valid_ds, batch_size and model are defined elsewhere in the script

# DataBunch takes the datasets and internally creates the data loaders
data = DataBunch.create(train_ds, valid_ds, bs=batch_size, path='./data')
# Learner uses Adam by default
learner = Learner(data, model, loss_func=F.cross_entropy, metrics=[accuracy])
# Gradient clipping
learner.clip = 0.1

# Run the learning-rate finder and plot loss against learning rate
learner.lr_find()
learner.recorder.plot()

# Weight decay penalizes large weights. Learn more at https://towardsdatascience.com/
learner.fit_one_cycle(5, 5e-3, wd=1e-4)

What could be wrong with my PyTorch training code?


Solution

  • fastai uses a lot of tricks under the hood. Here is a quick rundown of what they're doing and you're not.

    They are listed in the order I think matters most; the first two in particular should improve your scores.

    TLDR

    Use a learning-rate scheduler (preferably torch.optim.lr_scheduler.CyclicLR) and AdamW instead of SGD.
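A minimal sketch of that swap, assuming a recent PyTorch. The model and batch are hypothetical stand-ins (the original training script isn't shown here), and the CyclicLR bounds (5e-3 / 25 up to 5e-3, 50 batches up and 50 down) are illustrative values, not tuned ones. Note that CyclicLR's default "triangular" mode changes the learning rate linearly, whereas fastai's schedule is cosine-shaped.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the question's model and data (not shown in the post).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
xb = torch.randn(16, 3, 8, 8)
yb = torch.randint(0, 10, (16,))

# AdamW instead of SGD: weight decay is applied decoupled from the adaptive step.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=1e-4)

# CyclicLR: lr climbs linearly from base_lr to max_lr over step_size_up batches,
# then falls back to base_lr over step_size_down batches.
# cycle_momentum=False because AdamW has no 'momentum' hyperparameter and some
# PyTorch versions reject cycle_momentum=True for it.
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=5e-3 / 25,
    max_lr=5e-3,
    step_size_up=50,
    step_size_down=50,
    mode="triangular",
    cycle_momentum=False,
)

lrs = []
for step in range(100):
    loss = F.cross_entropy(model(xb), yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # step the scheduler once per batch, not once per epoch
    lrs.append(optimizer.param_groups[0]["lr"])
```

After 100 batches the recorded learning rates have gone up to 5e-3 and back down to the base value, which is one full cycle.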

    Longer version

    fit_one_cycle

    fastai uses Leslie Smith's 1cycle policy. In PyTorch you can create a similar routine with torch.optim.lr_scheduler.CyclicLR, but that requires some manual setup.

    Basically, it starts with a low learning rate, gradually increases it up to 5e-3 in your case, and then brings it back down again (making a cycle). You can adjust how the learning rate rises and falls (fastai does so using cosine annealing, IIRC).
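To make the shape concrete, here is a plain-Python sketch of a 1cycle-style schedule with cosine annealing in both phases. The defaults (start at max_lr / 25, peak after 30% of the steps, end at the starting value / 1e4) match those of PyTorch's torch.optim.lr_scheduler.OneCycleLR; I assume fastai's fit_one_cycle uses similar values, but check its source. one_cycle_lr and cos_anneal are hypothetical helper names.

```python
import math


def cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))


def one_cycle_lr(step: int, total_steps: int, max_lr: float,
                 div_factor: float = 25.0, final_div: float = 1e4,
                 pct_start: float = 0.3) -> float:
    """Learning rate at `step` under a 1cycle-style schedule (assumed defaults)."""
    start_lr = max_lr / div_factor         # low learning rate at the beginning
    end_lr = start_lr / final_div          # even lower learning rate at the very end
    warmup = int(total_steps * pct_start)  # number of steps spent increasing the lr
    if step < warmup:
        # Phase 1: cosine increase from start_lr up to max_lr
        return cos_anneal(start_lr, max_lr, step / warmup)
    # Phase 2: cosine decrease from max_lr down to end_lr
    return cos_anneal(max_lr, end_lr, (step - warmup) / (total_steps - warmup))


# lr for each of 100 batches with max_lr=5e-3, as in fit_one_cycle(..., 5e-3, ...)
schedule = [one_cycle_lr(s, 100, 5e-3) for s in range(101)]
```

The schedule starts at 2e-4, peaks at 5e-3 at step 30, and finishes far below where it started.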

    Your learning rate is too high at the beginning, and a scheduler should help; try that first of all.

    Optimizer

    In the provided code snippet you use torch.optim.SGD (since optim_fn is None, the default is used), which is usually harder to set up correctly.

    On the other hand, if you do manage to set it up correctly by hand, you might get better generalization.

    Also, fastai does not use plain Adam by default! It uses AdamW if true_wd is set (I think it is the default in your case anyway; see the source code). AdamW decouples weight decay from the adaptive learning rate, which should improve convergence (read here or the original paper).
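A toy illustration of what "decoupled" means, in plain Python on a single scalar weight. adam_step and run are hypothetical helpers, and the flat loss (zero loss gradient) is an artificial assumption so that only weight decay acts. With Adam + L2 (what torch.optim.Adam's weight_decay argument does) the decay term wd * w is fed through the adaptive normalization, so the weight shrinks by roughly lr per step no matter how small wd is. With AdamW the weight is multiplied by (1 - lr * wd) directly, the way torch.optim.AdamW applies it.

```python
import math


def adam_step(w: float, g: float, state: dict, lr: float,
              betas=(0.9, 0.999), eps: float = 1e-8) -> float:
    """One bias-corrected Adam update on a scalar parameter; returns the new weight."""
    b1, b2 = betas
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * g        # first-moment estimate
    state["v"] = b2 * state["v"] + (1 - b2) * g * g    # second-moment estimate
    m_hat = state["m"] / (1 - b1 ** state["t"])
    v_hat = state["v"] / (1 - b2 ** state["t"])
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)


def run(decoupled: bool, steps: int = 10, lr: float = 1e-2,
        wd: float = 1e-4, w: float = 1.0) -> float:
    """Apply `steps` updates with a flat loss so only weight decay moves w."""
    state = {"t": 0, "m": 0.0, "v": 0.0}
    for _ in range(steps):
        loss_grad = 0.0  # assumption: flat loss, so the loss contributes no gradient
        if decoupled:
            w = w - lr * wd * w                     # AdamW: decay applied directly
            w = adam_step(w, loss_grad, state, lr)  # adaptive step sees only the loss
        else:
            # Adam + L2: decay is added to the gradient and then normalized away
            w = adam_step(w, loss_grad + wd * w, state, lr)
    return w


w_l2 = run(decoupled=False)
w_adamw = run(decoupled=True)
```

Here w_l2 shrinks a lot while w_adamw barely moves: with L2 the effective decay is dictated by the adaptive step size rather than by wd, which is exactly the coupling AdamW removes.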

    Number of epochs

    Use the same number of epochs if you want to compare both approaches; currently it's apples to oranges.

    Gradient clipping

    You do not clip gradients (that line is commented out). It might or might not help, depending on the task; I would not focus on that one for now, tbh.
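If you do want to try it, here is a minimal PyTorch sketch: clip after backward() and before optimizer.step(). The tiny model and batch are hypothetical, and max_norm=0.1 just mirrors learner.clip = 0.1 from the question. I believe fastai's gradient-clipping callback also uses torch.nn.utils.clip_grad_norm_, but verify that against the fastai version you use.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical tiny model and batch, only to show where clipping goes.
model = nn.Linear(20, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=1e-4)
xb = torch.randn(32, 20) * 100  # large inputs, so the raw gradients are large
yb = torch.randint(0, 5, (32,))

optimizer.zero_grad()
loss = F.cross_entropy(model(xb), yb)
loss.backward()
# Rescale all gradients so their combined L2 norm is at most 0.1.
# The return value is the total norm measured *before* clipping.
norm_before = nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
optimizer.step()
```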

    Other tricks

    Read about Learner and fit_one_cycle and try to set up something similar in PyTorch (rough guidelines are described above).
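As a rough starting point, here is a sketch of a PyTorch loop that approximates learner.fit_one_cycle(5, 5e-3, wd=1e-4): AdamW with decoupled weight decay, torch.optim.lr_scheduler.OneCycleLR stepped once per batch, and gradient clipping at 0.1. It assumes PyTorch 1.3 or newer (for OneCycleLR). The synthetic dataset and linear model are hypothetical placeholders for your datasets and Wideresnet22, betas=(0.9, 0.99) is my assumption about fastai's Adam defaults, and validation is omitted to keep it short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins: swap in your real datasets and Wideresnet22 model.
train_ds = TensorDataset(torch.randn(256, 3, 8, 8), torch.randint(0, 10, (256,)))
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))

# Mirrors fit_one_cycle(5, 5e-3, wd=1e-4) plus learner.clip = 0.1
epochs, max_lr, wd, clip = 5, 5e-3, 1e-4, 0.1

# Assumption: betas=(0.9, 0.99) modeled on fastai's Adam defaults.
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr,
                              betas=(0.9, 0.99), weight_decay=wd)
# OneCycleLR: cosine warm-up to max_lr, then cosine annealing to a tiny lr.
# It also cycles beta1 inversely to the lr (0.95 -> 0.85 -> 0.95 by default).
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, epochs=epochs, steps_per_epoch=len(train_dl))

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for xb, yb in train_dl:
        loss = F.cross_entropy(model(xb), yb)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)  # gradient clipping
        optimizer.step()
        scheduler.step()  # once per batch, as in fastai's 1cycle
        total_loss += loss.item() * xb.size(0)
    print(f"epoch {epoch}: train_loss {total_loss / len(train_ds):.4f}")
```

Stepping OneCycleLR per batch matters: it is sized by epochs * steps_per_epoch, and calling it once per epoch would leave the learning rate stuck near its initial value.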

    You might also use some form of data augmentation to improve the scores even further, but that's outside the scope of the question, I suppose.