Tags: nlp, huggingface-transformers, bert-language-model, nlp-question-answering, roberta-language-model

Input/output format for Fine Tuning Huggingface RobertaForQuestionAnswering


I'm trying to fine-tune "RobertaForQuestionAnswering" on my custom dataset and I'm confused about the input params it takes. Here's the sample code.

>>> from transformers import RobertaTokenizer, RobertaForQuestionAnswering
>>> import torch

>>> tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
>>> model = RobertaForQuestionAnswering.from_pretrained('roberta-base')

>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors='pt')
>>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3])

>>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
>>> loss = outputs.loss
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits

I don't understand the variables start_positions and end_positions, which are passed to the model as input, or the variables start_scores and end_scores, which the model produces.


Solution

  • A Question Answering bot is basically a DL model that creates an answer by extracting part of the context (in your case, what is called text). This means that the goal of the QAbot is to identify the start and the end of the answer.


    Basic functioning of a QAbot:

    First of all, every word of the question and context is tokenized. This means it is (possibly divided into characters/subwords and then) converted into a number. How this happens depends on the type of tokenizer, which in turn depends on the model you are using, since each model must be used with its own tokenizer (loading it is what the third line of your code is doing). I suggest this very useful guide.
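
    To make the idea of token positions concrete, here is a rough sketch of how a question/context pair is laid out and how an answer given as characters turns into token indices. The whitespace splitter toy_tokenize and the helpers pack_pair and char_span_to_token_span are hypothetical stand-ins written for this illustration; RoBERTa's real tokenizer uses byte-level BPE, so the actual indices for your data will differ.

```python
# Toy stand-in for a real subword tokenizer: it only splits on whitespace.
# All function names here are illustrative and not part of transformers.

def toy_tokenize(text):
    """Return (token, start_char, end_char) triples for whitespace-separated words."""
    tokens, pos = [], 0
    for word in text.split():
        start = text.index(word, pos)
        end = start + len(word)
        tokens.append((word, start, end))
        pos = end
    return tokens


def pack_pair(question, context):
    """Lay out the pair in RoBERTa's pair format: <s> question </s></s> context </s>."""
    q_tokens = toy_tokenize(question)
    c_tokens = toy_tokenize(context)
    sequence = ["<s>"] + [tok for tok, _, _ in q_tokens] + ["</s>", "</s>"]
    context_offset = len(sequence)  # index of the first context token
    sequence += [tok for tok, _, _ in c_tokens] + ["</s>"]
    return sequence, c_tokens, context_offset


def char_span_to_token_span(c_tokens, context_offset, answer_start, answer_end):
    """Map a character span inside the context to token indices in the packed sequence."""
    start_tok = end_tok = None
    for i, (_, tok_start, tok_end) in enumerate(c_tokens):
        if start_tok is None and tok_end > answer_start:
            start_tok = i  # first token that reaches into the answer
        if tok_start < answer_end:
            end_tok = i  # last token that begins before the answer ends
    return context_offset + start_tok, context_offset + end_tok


question = "Who was Jim Henson?"
context = "Jim Henson was a nice puppet"
answer = "a nice puppet"
answer_start = context.index(answer)
answer_end = answer_start + len(answer)

sequence, c_tokens, context_offset = pack_pair(question, context)
start_position, end_position = char_span_to_token_span(
    c_tokens, context_offset, answer_start, answer_end
)
print(sequence)
print(start_position, end_position, sequence[start_position:end_position + 1])
```

    With a real tokenizer you would do the same mapping using the character offsets it reports for each token, rather than the whitespace offsets assumed here.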

    Then, the tokenized question + text are passed into the model, which performs its internal operations. Remember when I said at the beginning that the model will identify the start and the end of the answer? Well, it does so by calculating, for every token of the question + text, the probability that that particular token is the start of the answer. These probabilities are the softmaxed version of the start_logits. After that, the same operations are performed for the end token.

    So, this is what start_scores and end_scores are: the pre-softmax scores (logits), one per token, for that token being the start and the end of the answer, respectively.
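
    As a small illustration of the paragraph above, the sketch below turns per-token scores into probabilities and picks an answer span. The logit values are made up for a six-token sequence, and best_span is a hypothetical helper showing one simple decoding rule (highest start score plus end score, with the end not before the start); it is not a function from transformers.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def best_span(start_logits, end_logits, max_answer_len=30):
    """Return the (start, end) pair with the highest combined score, with end >= start."""
    best, best_score = (0, 0), -math.inf
    for s, s_logit in enumerate(start_logits):
        last = min(s + max_answer_len, len(end_logits))
        for e in range(s, last):
            score = s_logit + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best


# Invented scores, one per token, standing in for outputs.start_logits[0]
# and outputs.end_logits[0].
start_logits = [0.1, -1.2, 0.3, 4.0, 0.5, -0.7]
end_logits = [-0.5, 0.2, 0.1, 0.9, 1.1, 3.5]

start_probs = softmax(start_logits)
end_probs = softmax(end_logits)
span = best_span(start_logits, end_logits)
print(start_probs)
print(end_probs)
print(span)
```

    The predicted answer is then the text covered by the tokens from span[0] to span[1], inclusive.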


    So, what are start_positions and end_positions?

    As stated here, they are:

    start_positions (torch.LongTensor of shape (batch_size,), optional) – Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (sequence_length). Position outside of the sequence are not taken into account for computing the loss.

    end_positions (torch.LongTensor of shape (batch_size,), optional) – Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (sequence_length). Position outside of the sequence are not taken into account for computing the loss.
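
    To show how these labels are used, here is a plain-Python sketch of the loss described in the documentation quoted above: a cross-entropy over the start scores, another over the end scores, and their average. It is my reading of that description, not the library's code; the helper names and the logit values are invented for the example, and it assumes at least one example in the batch has an in-range label.

```python
import math


def log_softmax(logits):
    """Log-probabilities for one sequence of per-token scores."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


def mean_cross_entropy(batch_logits, positions):
    """Average negative log-probability of the labelled position per example."""
    losses = []
    for logits, position in zip(batch_logits, positions):
        seq_len = len(logits)
        position = max(0, min(position, seq_len))  # clamp to [0, seq_len]
        if position == seq_len:
            continue  # label falls outside the sequence: ignored
        losses.append(-log_softmax(logits)[position])
    return sum(losses) / len(losses)  # assumes at least one label was kept


def qa_loss(start_logits, end_logits, start_positions, end_positions):
    """Average of the start-position and end-position cross-entropies."""
    start_loss = mean_cross_entropy(start_logits, start_positions)
    end_loss = mean_cross_entropy(end_logits, end_positions)
    return (start_loss + end_loss) / 2


# One example (batch_size = 1) with five tokens; the scores are invented.
start_logits = [[0.1, 4.0, 0.3, 0.2, -0.5]]
end_logits = [[-0.3, 0.2, 0.1, 3.5, 0.0]]

# Labels that agree with the highest scores give a small loss ...
good_loss = qa_loss(start_logits, end_logits, [1], [3])
# ... and labels that disagree give a larger one.
bad_loss = qa_loss(start_logits, end_logits, [4], [0])
print(good_loss, bad_loss)
```

    In other words, start_positions and end_positions are the ground-truth token indices of the answer, one pair per example, and you only need to pass them during training, when you want outputs.loss.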


    Moreover, the model you are using (roberta-base, see the model on the HuggingFace repository and the RoBERTa official paper) has NOT been fine-tuned for QuestionAnswering. It is "just" a model trained with MaskedLanguageModeling, which means that it has a general understanding of the English language, but it is not suitable for question answering as it is. You can use it of course, but out of the box it would probably give suboptimal results.

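    I suggest you use the same model, in the version specifically fine-tuned on QuestionAnswering: roberta-base-squad2, see it on HuggingFace. As far as I know it is published on the hub under the deepset namespace, so the full identifier is deepset/roberta-base-squad2.

    In practical terms, you have to replace the lines where you load the model and the tokenizer with:

    # The hub identifier includes the publisher's namespace (deepset)
    tokenizer = RobertaTokenizer.from_pretrained('deepset/roberta-base-squad2')
    model = RobertaForQuestionAnswering.from_pretrained('deepset/roberta-base-squad2')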
    

    This will give much more accurate results.

    Bonus read: what fine-tuning is and how it works