
Precision, recall, F1 score with sklearn on PyTorch


I've been looking through examples but can't work out how to integrate the precision, recall and F1 metrics into my model's training loop. My code is as follows:

for epoch in range(num_epochs):

    # Calculate accuracy (stack tutorial, no n_total)
    n_correct = 0
    n_total = 0

    for i, (words, labels) in enumerate(train_loader):
        words = words.to(device)
        labels = labels.to(dtype=torch.long).to(device)

        # Forward pass
        outputs = model(words)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Feedforward tutorial solution
        _, predicted = torch.max(outputs, 1)
        n_correct += (predicted == labels).sum().item()
        n_total += labels.shape[0]

    accuracy = 100 * n_correct / n_total

    # Push to matplotlib
    train_losses.append(loss.item())
    train_epochs.append(epoch)
    train_acc.append(accuracy)

    # Loss and accuracy
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.2f}, Acc: {accuracy:.2f}')

Solution

  • Since you already have the predicted and labels variables, you can collect them during the epoch loop and convert them to NumPy arrays to calculate the required metrics.

    At the beginning of each epoch, initialize two empty lists: one for the predicted labels and one for the ground truth labels.

    for epoch in range(num_epochs):
        predicted_labels, ground_truth_labels = [], []
        ...
    

    Then, keep appending the respective entries to each list during the epoch:

        ...
        _, predicted = torch.max(outputs, 1)
        n_correct += (predicted == labels).sum().item()
        
        # appending
        predicted_labels.append(predicted.cpu().detach().numpy())
        ground_truth_labels.append(labels.cpu().detach().numpy())
    
    ...
    

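    Then, at the end of the epoch, you can pass them to sklearn.metrics.precision_recall_fscore_support. Note that it takes the ground truth first (y_true) and the predictions second (y_pred), and that with its default average=None it returns one score per class; pass average='macro', 'micro' or 'weighted' to get a single number for a multiclass problem.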

    Notes:

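    1. Both lists hold one array per batch, so you'll probably have to flatten them into single 1-D arrays (for example with numpy.concatenate) before passing them in.
    2. Read about torch.no_grad(); as good practice, wrap metric calculations and any evaluation-only forward passes in it so that autograd doesn't track them.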
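The end-of-epoch step can be illustrated on its own. The sketch below uses small made-up per-batch arrays in place of real model output (hypothetical data, shaped like what the appending step collects), flattens them with numpy.concatenate, and scores them with precision_recall_fscore_support. Macro averaging is an assumption here; micro or weighted averaging may suit your class balance better.

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Hypothetical per-batch results: one 1-D array of class indices per batch,
# as produced by predicted.detach().cpu().numpy() and labels.detach().cpu().numpy().
predicted_labels = [np.array([0, 1, 1]), np.array([2, 2, 0])]
ground_truth_labels = [np.array([0, 1, 2]), np.array([2, 1, 0])]

# Flatten the per-batch arrays into single 1-D arrays.
y_pred = np.concatenate(predicted_labels)
y_true = np.concatenate(ground_truth_labels)

# scikit-learn expects y_true first, then y_pred.
# With average="macro" the fourth return value (support) is None.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro", zero_division=0
)
print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")
# Precision: 0.67, Recall: 0.67, F1: 0.67

# With average=None (the default) you get per-class arrays instead:
# (precision, recall, f1, support), each with one entry per class.
per_class = precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
```

For these made-up arrays, class 0 is predicted perfectly while classes 1 and 2 each have precision and recall of 0.5, so the macro average of each metric is (1 + 0.5 + 0.5) / 3 = 2/3.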
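Putting the steps together, here is a minimal, runnable sketch of the training loop with per-epoch precision, recall and F1. Everything not in the question is a hypothetical stand-in: random 10-feature inputs with 3 classes, a single nn.Linear layer as the model, SGD as the optimizer, macro averaging, and the helper name evaluate. The evaluate helper shows where torch.no_grad() fits: around forward passes that are only used for scoring.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import precision_recall_fscore_support

torch.manual_seed(0)

# Hypothetical stand-ins for the question's data, model, loss and optimizer.
num_classes, num_epochs = 3, 5
features = torch.randn(120, 10)
targets = torch.randint(0, num_classes, (120,))
train_loader = DataLoader(TensorDataset(features, targets), batch_size=16, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(10, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def evaluate(model, loader):
    """Hypothetical helper: macro precision, recall, F1 without tracking gradients."""
    model.eval()
    preds, trues = [], []
    with torch.no_grad():  # no autograd graph is built inside this block
        for words, labels in loader:
            outputs = model(words.to(device))
            preds.append(outputs.argmax(dim=1).cpu().numpy())
            trues.append(labels.numpy())
    precision, recall, f1, _ = precision_recall_fscore_support(
        np.concatenate(trues), np.concatenate(preds), average="macro", zero_division=0)
    return precision, recall, f1


history = []
for epoch in range(num_epochs):
    model.train()
    predicted_labels, ground_truth_labels = [], []  # reset at the start of each epoch

    for words, labels in train_loader:
        words = words.to(device)
        labels = labels.to(dtype=torch.long).to(device)

        outputs = model(words)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Collect this batch's predictions and targets as NumPy arrays.
        _, predicted = torch.max(outputs, 1)
        predicted_labels.append(predicted.detach().cpu().numpy())
        ground_truth_labels.append(labels.detach().cpu().numpy())

    # Flatten the per-batch arrays, then score the whole epoch (y_true first).
    precision, recall, f1, _ = precision_recall_fscore_support(
        np.concatenate(ground_truth_labels), np.concatenate(predicted_labels),
        average="macro", zero_division=0)
    history.append((precision, recall, f1))
    # loss.item() is the last batch's loss, as in the question's loop.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.2f}, "
          f"P: {precision:.2f}, R: {recall:.2f}, F1: {f1:.2f}")

# Demonstration only: in practice you would pass a held-out loader here.
print(evaluate(model, train_loader))
```

Scores collected on the training batches are computed while the weights are still changing within the epoch, so for a cleaner measurement you would typically run an evaluation pass like this on a separate validation loader after each epoch.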