
Plot training and validation loss in PyTorch


I am using PyTorch to train my CNN. I want to plot the training and validation loss curves to visualize the model's performance. How can I plot both curves?

I have the code below:

import numpy as np
import torch

# model, train_dataset and val_dataset are defined elsewhere

# create a function (this is my favorite choice)
def RMSELoss(predicted, target):
    return torch.sqrt(torch.mean((predicted - target) ** 2))

criterion = RMSELoss

# loss = torch.sqrt(criterion(x, y))
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 300

n_total_steps = len(train_dataset)

trainingEpoch_loss = []
validationEpoch_loss = []

for epoch in range(epochs):
    step_loss = []
    model.train()
    for i, data in enumerate(train_dataset):
        feature, target = data['data'].type(torch.FloatTensor), torch.tensor(data['target']).type(torch.FloatTensor)

        # Clear the gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(feature)
        # Compute the loss
        training_loss = criterion(outputs, target)
        # Compute gradients
        training_loss.backward()
        # Update weights
        optimizer.step()
        # Record this step's loss
        step_loss.append(training_loss.item())
        if (i + 1) % 1 == 0:  # prints every step; raise the modulus to print less often
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {training_loss.item():.4f}')
    trainingEpoch_loss.append(np.array(step_loss).mean())

    model.eval()  # switches dropout/batch-norm layers to evaluation behavior
    validationStep_loss = []  # reset once per epoch, before the validation loop
    with torch.no_grad():  # no gradients are needed for validation
        for i, data in enumerate(val_dataset):
            feature, target = data['data'].type(torch.FloatTensor), torch.tensor(data['target']).type(torch.FloatTensor)

            # Forward pass
            outputs = model(feature)
            # Compute the loss
            validation_loss = criterion(outputs, target)
            # Record this step's loss
            validationStep_loss.append(validation_loss.item())
    validationEpoch_loss.append(np.array(validationStep_loss).mean())

Can you tell me whether I am doing this right? Also, how do I plot the training and validation loss?


Solution

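  • You are right to collect the per-epoch losses in the trainingEpoch_loss and validationEpoch_loss lists. One caveat: create validationStep_loss once per epoch, before the validation loop; if it is reset inside the loop, only the last batch's loss gets averaged. It is also common to run validation under torch.no_grad() so that no gradients are tracked. After training, plot the two lists: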

    from matplotlib import pyplot as plt
    plt.plot(trainingEpoch_loss, label='train_loss')
    plt.plot(validationEpoch_loss, label='val_loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()
    

    See the matplotlib documentation for fancier plot features.
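For reference, here is a minimal, self-contained sketch of the whole pattern: per-epoch training and validation losses collected into two lists and then plotted. Everything specific here is an assumption chosen only to make it runnable: the toy data (y = 3x + 1 plus noise), the single nn.Linear model, the learning rate 0.05, the batch size 32, and the hypothetical helper names rmse_loss, train_and_track and plot_losses. It uses matplotlib's Agg backend and saves a PNG instead of calling plt.show(), so it also runs on a machine without a display.

```python
# Minimal sketch (hypothetical toy setup): track per-epoch train/val loss and plot it.
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; use plt.show() instead when working interactively
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def rmse_loss(predicted, target):
    # Same RMSE formula as the question's RMSELoss
    return torch.sqrt(torch.mean((predicted - target) ** 2))


def train_and_track(epochs=20, seed=0):
    torch.manual_seed(seed)
    # Hypothetical toy regression data: y = 3x + 1 plus a little noise
    x = torch.randn(256, 1)
    y = 3 * x + 1 + 0.1 * torch.randn(256, 1)
    train_loader = DataLoader(TensorDataset(x[:200], y[:200]), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(x[200:], y[200:]), batch_size=32)

    model = nn.Linear(1, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

    train_losses, val_losses = [], []
    for _ in range(epochs):
        model.train()
        step_losses = []
        for feature, target in train_loader:
            optimizer.zero_grad()
            loss = rmse_loss(model(feature), target)
            loss.backward()
            optimizer.step()
            step_losses.append(loss.item())
        train_losses.append(sum(step_losses) / len(step_losses))

        model.eval()
        step_losses = []  # reset once per epoch, before the validation loop
        with torch.no_grad():  # no gradient tracking during validation
            for feature, target in val_loader:
                step_losses.append(rmse_loss(model(feature), target).item())
        val_losses.append(sum(step_losses) / len(step_losses))
    return train_losses, val_losses


def plot_losses(train_losses, val_losses, path="loss.png"):
    epochs = range(1, len(train_losses) + 1)  # label the x-axis by epoch number
    fig, ax = plt.subplots()
    ax.plot(epochs, train_losses, label="train_loss")
    ax.plot(epochs, val_losses, label="val_loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("RMSE loss")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)
    return path


if __name__ == "__main__":
    train_losses, val_losses = train_and_track()
    plot_losses(train_losses, val_losses)
```

Averaging the per-batch losses is a simple per-epoch summary; when batch sizes differ (here the last training and validation batches are smaller), it weights every batch equally rather than every sample.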