Search code examples
python python-3.x pytorch bert-language-model huggingface-transformers

Getting predict_proba from a BERT classifier


I have a classifier on top of BERT, and I would like to get the predicted probabilities in order to create a ROC curve. How do I get them? The predicted probabilities will be used to calculate the TPR, FPR and thresholds for the ROC curve.
Here is the code:

class BertBinaryClassifier(nn.Module):
    """Binary classifier: BERT encoder -> dropout -> linear -> sigmoid.

    Calling the model returns sigmoid-squashed scores, i.e. values in
    [0, 1] that can be used as positive-class probabilities (for
    ``nn.BCELoss`` during training or for a ROC curve at evaluation time).

    Args:
        dropout: Dropout probability applied to BERT's pooled output.
    """

    def __init__(self, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        # Size the head from the loaded encoder's config rather than
        # hard-coding 768, so the head always matches the checkpoint.
        self.linear = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, tokens, masks=None):
        """Return a probability for each input sequence.

        Args:
            tokens: Token ids, passed to BERT as its first argument.
            masks: Optional attention mask, forwarded as ``attention_mask``.

        Returns:
            Sigmoid output of the linear head; the last dimension has size 1.
        """
        # NOTE(review): `output_all_encoded_layers` looks like the
        # pytorch-pretrained-bert keyword; confirm which library BertModel
        # comes from before upgrading, as newer APIs may not accept it.
        _, pooled_output = self.bert(tokens, attention_mask=masks, output_all_encoded_layers=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)  # unconstrained score
        prediction = self.sigmoid(linear_output)     # squashed into [0, 1]
        return prediction
# Config setting
BATCH_SIZE = 4
EPOCHS = 5

# Making dataloaders
# Training batches are drawn in random order; test batches keep dataset order
# so that collected predictions line up with the test labels.
train_dataset = torch.utils.data.TensorDataset(
    train_tokens_tensor, train_masks_tensor, train_y_tensor
)
train_sampler = torch.utils.data.RandomSampler(train_dataset)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE
)
test_dataset = torch.utils.data.TensorDataset(
    test_tokens_tensor, test_masks_tensor, test_y_tensor
)
test_sampler = torch.utils.data.SequentialSampler(test_dataset)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, sampler=test_sampler, batch_size=BATCH_SIZE
)

bert_clf = BertBinaryClassifier()
# Put the model on the same `device` the batches are moved to in the
# training/evaluation loops, instead of unconditionally requiring CUDA.
bert_clf = bert_clf.to(device)
#wandb.watch(bert_clf)
optimizer = torch.optim.Adam(bert_clf.parameters(), lr=3e-6)

# training
# The loss object holds no per-batch state, so build it once up front.
loss_func = nn.BCELoss()
# Number of batches per epoch, used only for the progress line.
num_batches = len(train_dataloader)
for epoch_num in range(EPOCHS):
    bert_clf.train()  # switch to training mode (dropout active)
    train_loss = 0
    for step_num, batch_data in enumerate(train_dataloader):
        token_ids, masks, labels = (t.to(device) for t in batch_data)
        preds = bert_clf(token_ids, masks)
        batch_loss = loss_func(preds, labels)
        train_loss += batch_loss.item()
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        #wandb.log({"Training loss": train_loss})
        print('Epoch: ', epoch_num + 1)
        # Running mean of the loss over the batches seen so far this epoch.
        print("\r" + "{0}/{1} loss: {2} ".format(step_num, num_batches, train_loss / (step_num + 1)))

# evaluating on test
bert_clf.eval()  # switch to evaluation mode (dropout disabled)
bert_predicted = []  # hard 0/1 decisions at a 0.5 threshold
all_logits = []      # per-example model outputs (NumPy scalars)
probs = []           # per-example probabilities as plain Python floats
# The loss object holds no per-batch state, so build it once up front.
loss_func = nn.BCELoss()
with torch.no_grad():
    test_loss = 0
    for step_num, batch_data in enumerate(test_dataloader):
        token_ids, masks, labels = (t.to(device) for t in batch_data)
        # The model's forward() ends in a sigmoid, so despite the name these
        # "logits" are already probabilities in [0, 1].
        logits = bert_clf(token_ids, masks)
        loss = loss_func(logits, labels)
        test_loss += loss.item()
        numpy_logits = logits.cpu().detach().numpy()
        # Flatten to 1-D on the CPU before collecting. Extending a list with
        # the device tensor itself would store 0-d tensors, which sklearn's
        # metric functions cannot consume.
        pr = numpy_logits.ravel()
        probs += pr.tolist()
        #print(numpy_logits)
        #wandb.log({"Testing loss": test_loss})
        bert_predicted += list(numpy_logits[:, 0] > 0.5)
        all_logits += list(numpy_logits[:, 0])

I am able to get the prediction scores to calculate the accuracy or F1 score, but not the probabilities needed to create the ROC curve. Thanks.


Solution

  • In your forward, you apply a sigmoid to the output of the linear layer, so the model already returns probabilities:

    def forward(self, tokens, masks=None):
        # Annotated walk-through: each step is labelled with what its output represents.
        _, pooled_output = self.bert(...)             # Get output of BERT (second returned value is kept)
        dropout_output = self.dropout(pooled_output)  # Apply dropout to the pooled output
        linear_output = self.linear(dropout_output)   # Take linear combination of outputs
                                                      # (unconstrained score - "logits")
    
        prediction = self.sigmoid(linear_output)      # Squash scores with a sigmoid
                                                      # (constrained to [0,1] - "probabilities")
        return prediction
    

    Hence the result of calling your model is already a probability and, once converted to a CPU NumPy array, can be supplied to calculate the False Positive and True Positive rates, e.g.:

    from sklearn import metrics
    
    ...
    
    test_probs = bert_clf(token_ids, masks)
    # sklearn works on CPU NumPy arrays, so detach the tensors from the graph,
    # move them off the GPU and flatten them to 1-D before computing the curve.
    # NOTE: this scores a single batch; for the whole test set, concatenate the
    # labels and probabilities collected across all batches first.
    y_true = labels.detach().cpu().numpy().ravel()
    y_score = test_probs.detach().cpu().numpy().ravel()
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
    roc_auc = metrics.auc(fpr, tpr)