
Can the ReduceLROnPlateau scheduler in PyTorch use a test-set metric to decrease the learning rate?


Hi, I am currently learning how to use learning-rate schedulers in PyTorch. I came across the following code:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets

# Set seed
torch.manual_seed(0)

# Where to add a new import
from torch.optim.lr_scheduler import ReduceLROnPlateau

'''
STEP 1: LOADING DATASET
'''

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

'''
STEP 2: MAKING DATASET ITERABLE
'''

batch_size = 100
n_iters = 6000
num_epochs = n_iters / (len(train_dataset) / batch_size)
num_epochs = int(num_epochs)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

'''
STEP 3: CREATE MODEL CLASS
'''
class FeedforwardNeuralNetModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FeedforwardNeuralNetModel, self).__init__()
        # Linear function
        self.fc1 = nn.Linear(input_dim, hidden_dim) 
        # Non-linearity
        self.relu = nn.ReLU()
        # Linear function (readout)
        self.fc2 = nn.Linear(hidden_dim, output_dim)  

    def forward(self, x):
        # Linear function
        out = self.fc1(x)
        # Non-linearity
        out = self.relu(out)
        # Linear function (readout)
        out = self.fc2(out)
        return out
'''
STEP 4: INSTANTIATE MODEL CLASS
'''
input_dim = 28*28
hidden_dim = 100
output_dim = 10

model = FeedforwardNeuralNetModel(input_dim, hidden_dim, output_dim)

'''
STEP 5: INSTANTIATE LOSS CLASS
'''
criterion = nn.CrossEntropyLoss()


'''
STEP 6: INSTANTIATE OPTIMIZER CLASS
'''
learning_rate = 0.1

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)

'''
STEP 7: INSTANTIATE LEARNING RATE SCHEDULER CLASS
'''
# new_lr = lr * factor
# mode='max': the tracked metric (validation accuracy) is expected to increase
# patience: number of epochs with no improvement to tolerate before decreasing the LR
#     patience = 0: reduce the LR after 1 bad epoch
# factor: multiplicative decay factor
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0, verbose=True)

'''
STEP 8: TRAIN THE MODEL
'''
iter = 0
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten images to (batch_size, 784); the inputs themselves do not need gradients
        images = images.view(-1, 28*28)

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get output/logits
        outputs = model(images)

        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()

        iter += 1

        if iter % 500 == 0:
            # Calculate Accuracy
            correct = 0
            total = 0
            # Iterate through test dataset; no gradients are needed for evaluation
            with torch.no_grad():
                for images, labels in test_loader:
                    # Flatten images to (batch_size, 784)
                    images = images.view(-1, 28*28)

                    # Forward pass only to get logits/output
                    outputs = model(images)

                    # Get predictions from the maximum value
                    _, predicted = torch.max(outputs, 1)

                    # Total number of labels
                    total += labels.size(0)

                    # Total correct predictions; .item() turns the 0-dim tensor
                    # into a plain Python int so accuracy is an ordinary number
                    correct += (predicted == labels).sum().item()

            accuracy = 100 * correct / total

            # Print Loss
            # print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))

    # Decay Learning Rate, pass validation accuracy for tracking at every epoch
    print('Epoch {} completed'.format(epoch))
    print('Loss: {}. Accuracy: {}'.format(loss.item(), accuracy))
    print('-'*20)
    scheduler.step(accuracy)

I am using the above strategy. The only thing I don't understand is the last line of the code: how can they use the accuracy on the test data to decrease the learning rate via the scheduler? Is it legitimate, during training, to show the test accuracy to the scheduler and let it reduce the learning rate? I found something similar in a ResNet main.py on GitHub too. Can someone please clarify?
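The last line, `scheduler.step(accuracy)`, only hands the scheduler a number; ReduceLROnPlateau has no notion of which dataset that number came from. A minimal sketch to see this in isolation, assuming a single dummy parameter in place of the model and a made-up list of per-epoch accuracies (hypothetical values, not results from the script above):

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# A single dummy parameter is enough to build an optimizer; no model is needed
# to observe how the scheduler reacts to the metric it is given.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

# Same settings as in the question: maximize the metric and cut the LR by 10x
# as soon as one epoch fails to improve on the best value seen so far.
scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=0)

# Hypothetical per-epoch accuracies (made up for illustration).
accuracies = [90.0, 95.0, 95.0, 96.0, 96.0]

lrs = []
for acc in accuracies:
    scheduler.step(acc)  # the scheduler only ever sees this number
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

With `mode='max'` and the default relative `threshold` of 1e-4, a value only counts as an improvement if it beats the best value so far by that margin. The repeated 95.0 and 96.0 therefore each count as a bad epoch and, with `patience=0`, immediately multiply the learning rate by `factor` (0.1 to 0.01, then 0.01 to 0.001). Since the scheduler accepts any number, nothing in the API stops you from passing test accuracy; whether you should is a methodology question.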


Solution

  • I think there might be some confusion regarding the term test here.

    Difference between test and validation data

    What the code calls test data is really being used as a validation set, not as an actual test set. The difference is that the validation set is used during training to see how well the model generalizes. Normally people just split off a part of the training data and use that for validation. Your code does not reuse the training data: MNIST(train=True) loads the 60,000 training images and MNIST(train=False) loads the separate 10,000 test images. But because the scheduler is given the accuracy on those 10,000 images, they effectively become validation data and can no longer serve as an untouched test set.

    To work in a strictly scientific way, your model should never see the actual test data during training, only the training and validation data. This way we can assess the model's actual ability to generalize to unseen data after training.

    Reducing learning rate based on validation accuracy

    The reason you use validation data (called test data in your case) to reduce the learning rate is probably that, if you did this using the training data and training accuracy, the model would be more likely to overfit.

    Why?

    A plateau in training accuracy does not necessarily imply a plateau in validation accuracy, and vice versa. So you could be moving in a promising direction for validation accuracy (and thus toward parameters that generalize well) when the learning rate is suddenly reduced or increased because of a plateau (or the lack of one) in the training accuracy.

    Even when scheduling the learning rate on validation data, depending on the task at hand and the model's capacity, you are likely to run into overfitting at some point when working with deep neural nets. If you schedule on the training data instead, you let the net reach even better spots on the training data, which will likely be even worse on the validation data.
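A minimal sketch of scheduling on a validation split carved out of the training data, keeping the test set for a single final evaluation. It assumes a hypothetical stand-in dataset of 1,000 random tensors so it runs without downloading MNIST, and `evaluate` is a hypothetical helper; in the question's script you would pass `train_dataset` to `random_split` instead and keep `test_loader` for the very end. `random_split`, `TensorDataset` and `ReduceLROnPlateau` are standard PyTorch APIs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# Hypothetical stand-in for MNIST's training images: random flattened 28x28
# "images" with random labels, so the sketch runs without any download.
full_train = TensorDataset(torch.rand(1000, 28 * 28), torch.randint(0, 10, (1000,)))

# Hold out 10% of the *training* data as a validation set.
n_val = len(full_train) // 10
train_subset, val_subset = random_split(
    full_train,
    [len(full_train) - n_val, n_val],
    generator=torch.Generator().manual_seed(0),
)

train_loader = DataLoader(train_subset, batch_size=100, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=100, shuffle=False)

model = torch.nn.Sequential(
    torch.nn.Linear(28 * 28, 100), torch.nn.ReLU(), torch.nn.Linear(100, 10)
)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.1, patience=0
)


def evaluate(loader):
    """Hypothetical helper: return accuracy in percent without tracking gradients."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    model.train()
    return 100 * correct / total


for epoch in range(3):
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    val_accuracy = evaluate(val_loader)
    scheduler.step(val_accuracy)  # schedule on validation accuracy only
    print(f"epoch {epoch}: val accuracy {val_accuracy:.1f}%, lr {optimizer.param_groups[0]['lr']}")

# Only after all training decisions are made would you evaluate once on the
# real test set, e.g. evaluate(test_loader) with the question's test_loader.
```

Because the labels here are random, the printed accuracies are meaningless; the point is the data flow: the scheduler only ever sees `val_accuracy`, and no test-set number influences training.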