
How to access the predictions of a PyTorch classification model? (BERT)


I'm running this file: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py

This is the prediction code for one input batch:

  input_ids = input_ids.to(device)
  input_mask = input_mask.to(device)
  segment_ids = segment_ids.to(device)
  label_ids = label_ids.to(device)

  with torch.no_grad():
      logits = model(input_ids, segment_ids, input_mask, labels=None)

      loss_fct = CrossEntropyLoss()
      tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

      eval_loss += tmp_eval_loss.mean().item()
      nb_eval_steps += 1
      if len(preds) == 0:
          preds.append(logits.detach().cpu().numpy())
      else:
          preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)

The task is binary classification, and I want to access the binary output.

I've tried this:

  curr_pred = logits.detach().cpu()

  if len(preds) == 0:
      preds.append(curr_pred.numpy())
  else:
      preds[0] = np.append(preds[0], curr_pred.numpy(), axis=0)

  probabilities = curr_pred.softmax(1).numpy()[:, 1]

But the results look odd, so I'm not sure this is the correct way.

My hypothesis: the model returns the output of the last layer (the logits), so after softmax I get the actual probabilities, a vector of dimension 2 holding the probability of the first class and the probability of the second class.


Solution

  • Look at this part of the run_classifier.py code:

        # copied from the run_classifier.py code 
        eval_loss = eval_loss / nb_eval_steps
        preds = preds[0]
        if output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(task_name, preds, all_label_ids.numpy())
    

    You are just missing these two lines, which turn the accumulated logits into class labels:

        preds = preds[0]
        preds = np.argmax(preds, axis=1)
    

    They then use preds to compute the accuracy:

        def simple_accuracy(preds, labels):
            return (preds == labels).mean()
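
To make the relationship between the logits, the softmax probabilities, and the argmax labels concrete, here is a minimal NumPy sketch. The logits and labels are made-up values for illustration, not output from the BERT model, and the `softmax` helper is a hypothetical stand-in for calling `.softmax(1)` on the tensor. It suggests that taking softmax over the class axis does give per-class probabilities, and that for two classes the argmax label matches thresholding the class-1 probability at 0.5.

```python
import numpy as np

def softmax(logits, axis=1):
    # Hypothetical helper: subtract the row-wise max for numerical
    # stability, then normalize the exponentials so each row sums to 1.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def simple_accuracy(preds, labels):
    # Same definition as in the answer above.
    return (preds == labels).mean()

# Made-up logits for 4 examples and 2 classes (assumption, not real model output).
logits = np.array([[2.0, -1.0],
                   [0.5, 1.5],
                   [-0.3, 0.2],
                   [3.0, 0.0]])
labels = np.array([0, 1, 0, 0])

probs = softmax(logits)            # shape (4, 2), rows sum to 1
pos_probs = probs[:, 1]            # probability of class 1 per example
preds = np.argmax(logits, axis=1)  # hard binary labels

accuracy = simple_accuracy(preds, labels)
```

Because softmax preserves the ordering of the logits within each row, `np.argmax(logits, axis=1)` and `np.argmax(probs, axis=1)` return the same labels; softmax is only needed when you want the probabilities themselves.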