Tags: pytorch, nlp, huggingface-transformers, language-model, perplexity

Why is perplexity calculation giving different results for the same input?


I'm following the Hugging Face documentation on calculating the perplexity of fixed-length models. I'm trying to verify that the formula works for various strings, but I'm getting odd behavior. In particular, the docs say:

We don’t want the log-likelihood for the tokens we’re just treating as context to be included in our loss, so we can set these targets to -100 so that they are ignored

So given 2 different contexts but the same remaining tokens, the formula should return the same perplexity. However, it does not:

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

context_1 = 'here is some context_1 and some more stuff'
context_2 = 'here is some context and some more stuff and more stuff aspodkaspd'
answer_1 = 'this is not the answer'

# Tokenize (context + answer), and measure how many tokens each context occupies on its own
input_ids_wrong = tokenizer(context_1 + answer_1, return_tensors="pt").input_ids
input_ids_correct = tokenizer(context_2 + answer_1, return_tensors="pt").input_ids
context_1_tokens_length = len(tokenizer(context_1, return_tensors="pt").input_ids[0])
context_2_tokens_length = len(tokenizer(context_2, return_tensors="pt").input_ids[0])

# Copy the inputs and mask the context positions so that (supposedly)
# only the answer tokens contribute to the loss
target_ids_wrong = input_ids_wrong.clone()
target_ids_correct = input_ids_correct.clone()

target_ids_wrong[:, :context_1_tokens_length] = -100
target_ids_correct[:, :context_2_tokens_length] = -100

print('target_ids_wrong', target_ids_wrong)
print('target_ids_correct', target_ids_correct)

with torch.no_grad():
    outputs_wrong = model(input_ids_wrong, labels=target_ids_wrong)
    outputs_correct = model(input_ids_correct, labels=target_ids_correct)
    
    neg_log_likelihood_wrong = outputs_wrong.loss
    neg_log_likelihood_correct = outputs_correct.loss

    # perplexity = exp(average negative log-likelihood over the unmasked tokens)
    ppl_wrong = torch.exp(neg_log_likelihood_wrong)
    ppl_correct = torch.exp(neg_log_likelihood_correct)
    print('ppl_wrong', ppl_wrong)
    print('ppl_correct', ppl_correct)

Output:

    target_ids_wrong tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,   19,
               59,    8, 1525,    1]])
    target_ids_correct tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100,   19,   59,    8, 1525,    1]])
    ppl_wrong tensor(9.0377)
    ppl_correct tensor(21.1208)

I tried this with other models as well (e.g., gpt2 and sshleifer/tiny-gpt2) and got the same odd behavior. The T5 docs say:

we must make sure that padding token id's of the labels are not taken into account by the loss function. In PyTorch and TensorFlow, this can be done by replacing them with -100, which is the ignore_index of the CrossEntropyLoss.

They also note on the same page that the pad token id for T5 is 0. So I don't understand why the pad token would be taken into account.

So I tried replacing the -100 with 0 and actually got a different perplexity score (the two examples were still different from each other, and also different from the -100 version), which makes me think the -100 tokens aren't actually being ignored for some reason.
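
For reference, the same difference shows up with a plain cross-entropy call on toy logits (made-up numbers, independent of the model), comparing a first position labelled -100 with one labelled 0:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)  # 4 positions, vocab of 5 (toy numbers)

labels_masked = torch.tensor([-100, 2, 3, 1])  # first position set to -100
labels_pad = torch.tensor([0, 2, 3, 1])        # first position set to pad id 0

# ignore_index defaults to -100, so the first position is skipped here...
print(F.cross_entropy(logits, labels_masked))
# ...but counted as a real target here, so the averages differ
print(F.cross_entropy(logits, labels_pad))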

Am I missing something?


Solution

  • I think you might not be thinking of "ignore the context" in the same way that they are. When they say they want the context to be ignored, they effectively mean they want to compute the log probabilities of the answer conditioned on the context; e.g., they want something like P(answer|context_1) and P(answer|context_2) instead of P(context_1 + answer) or P(context_2 + answer). If you want to ignore the context entirely, that would be P(answer) - in which case just don't pass the context into the model.

    Basically, the probability of the answer SHOULD change when given different contexts - but you want only the conditional probability of the answer given the context, not the joint probability of the answer and context. You want "How likely is this answer given this context?", not "How likely am I to see this context and answer in general?".

    Lastly, tokens with a value of -100 are ignored by the cross-entropy loss - that's why they're used, and why you get a different value if you set them to 0: a 0 there is treated as a real target (T5's pad token id) rather than being skipped.
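
    One way to convince yourself of that last point is to recompute the loss by hand from the returned logits, letting cross_entropy skip the -100 positions via its default ignore_index. Here is a minimal sketch reusing the second example from the question; if the -100 positions were being counted by the model's loss, the two numbers would not match:

        import torch
        import torch.nn.functional as F
        from transformers import T5Tokenizer, T5ForConditionalGeneration

        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        model = T5ForConditionalGeneration.from_pretrained("t5-small")

        context = 'here is some context and some more stuff and more stuff aspodkaspd'
        answer = 'this is not the answer'

        input_ids = tokenizer(context + answer, return_tensors="pt").input_ids
        context_length = len(tokenizer(context, return_tensors="pt").input_ids[0])

        target_ids = input_ids.clone()
        target_ids[:, :context_length] = -100  # mask the context positions

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # Recompute the loss by hand: flatten logits and labels, and let
            # cross_entropy skip the -100 positions (its default ignore_index)
            logits = outputs.logits
            manual_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), target_ids.view(-1)
            )

        print(outputs.loss, manual_loss)  # these should match

    The two losses agreeing shows the masking works as documented; the perplexities in the question differ only because the answer is being conditioned on two different contexts.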