
Type errors with BERT example


I'm new to the BERT QA model and was trying to follow the example found in this article. When I run the code attached to the example, it raises the following error: TypeError: argmax(): argument 'input' (position 1) must be Tensor, not str.

Here is the code that I tried running:

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question = '''SAMPLE QUESTION'''

paragraph = '''SAMPLE PARAGRAPH'''
            
encoding = tokenizer.encode_plus(text=question, text_pair=paragraph, add_special_tokens=True)

inputs = encoding['input_ids']  #Token IDs
sentence_embedding = encoding['token_type_ids']  #Segment IDs (question vs. paragraph)
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens

start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(start_scores)

end_index = torch.argmax(end_scores)

answer = ' '.join(tokens[start_index:end_index+1])

The error is raised at line 13 of this code, where I call torch.argmax on start_scores to get its maximum element: it complains that start_scores is not a tensor. When I printed this variable, it showed the string "start_logits". Does anyone know a solution to this issue?


Solution

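  • After checking the BERT documentation, we found that in recent versions of transformers the model returns an output object with several named fields, not a plain tuple of start and end scores. Unpacking that object into two variables iterates over its field names, which is why start_scores ended up holding the string "start_logits". We changed the code to read the fields by name instead: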

    
    outputs = model(input_ids=torch.tensor([inputs]),token_type_ids=torch.tensor([sentence_embedding]))
    
    start_index = torch.argmax(outputs.start_logits)
    
    end_index = torch.argmax(outputs.end_logits)
    
    answer = ' '.join(tokens[start_index:end_index+1])
    

    Always refer to the documentation first :"D
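
To see why the original unpacking produced strings, here is a minimal sketch that does not need transformers or torch at all. FakeQAOutput is a hypothetical stand-in I made up for illustration; it only assumes that the real output object behaves like a dictionary whose keys are also readable as attributes. The scores are invented numbers.

```python
class FakeQAOutput(dict):
    """Hypothetical stand-in for the dict-like object returned by the model."""

    def __getattr__(self, name):
        # Expose dictionary keys as attributes, e.g. output.start_logits
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


# Made-up scores for a three-token input (assumption, not real model output).
output = FakeQAOutput(start_logits=[0.1, 2.5, 0.3], end_logits=[0.2, 0.1, 3.0])

# Tuple unpacking iterates over the KEYS, so both names are bound to strings.
start_scores, end_scores = output
print(type(start_scores).__name__, start_scores)

# Reading the fields by name returns the stored scores instead.
# A plain-Python argmax stands in for torch.argmax here.
start_index = max(range(len(output.start_logits)), key=output.start_logits.__getitem__)
end_index = max(range(len(output.end_logits)), key=output.end_logits.__getitem__)
print(start_index, end_index)
```

If you would rather keep the original two-variable unpacking, passing return_dict=False to the model call, or calling to_tuple() on the returned object, should give you a plain tuple in current transformers releases. Check the documentation for the version you have installed before relying on either.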