I'm trying to compute the cross-entropy loss for a text using GPT-2, following the idea from this article: https://huggingface.co/docs/transformers/perplexity
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
from tqdm import tqdm
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id, cache_dir='.')
encodings = tokenizer("She felt his demeanor was sweet and endearing.", return_tensors="pt")
max_length = model.config.n_positions
seq_len = encodings.input_ids.size(1)
target_ids = encodings.input_ids.clone()
# target_ids[:, :-seq_len] = -100  # <-- the commented line
with torch.no_grad():
    outputs = model(encodings.input_ids, labels=target_ids)
print(outputs.loss.item())
Whether that line is commented out or not makes no difference: I always get 4.352320194244385 as the printed output. According to the documentation https://huggingface.co/docs/transformers/v4.35.2/en/model_doc/gpt2#transformers.GPT2LMHeadModel :
labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. Indices are selected in [-100, 0, ..., config.vocab_size - 1]. All labels set to -100 are ignored (masked), the loss is only computed for labels in [0, ..., config.vocab_size - 1]
Why do I get the same result?
Let's first start with how decoder-only models roughly work in Hugging Face.
The main ingredients are input_ids, label_ids and the attention_mask. The latter is not relevant here.
Let's say we have the input "The answer is:" and we want the model to learn to answer with "42".
Translated to tokens, "The answer is: 42" is [464, 3280, 318, 25, 5433] (where ":" is 25 and " 42" is 5433).
For the learning objective we only want to learn "42". The targets are therefore label_ids = [-100, -100, -100, -100, 5433], so the model will not learn that "The answer" is followed by "is:".
A decoder-only model expects input and labels to have the same shape. In contrast, encoder-decoder models can have "The answer is:" as the input and "42" as the output.
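To make this concrete, here is a minimal sketch of passing such labels to the model (using the small gpt2 checkpoint, which shares its tokenizer with gpt2-large; the token IDs are the ones listed above):

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

input_ids = tokenizer("The answer is: 42", return_tensors="pt").input_ids  # [[464, 3280, 318, 25, 5433]]
label_ids = input_ids.clone()
label_ids[:, :-1] = -100  # ignore every position except the last token, " 42"

with torch.no_grad():
    outputs = model(input_ids, labels=label_ids)
print(outputs.loss.item())  # loss only for predicting " 42" after "The answer is:"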
-100 is the default ignore_index of the standard torch.nn.CrossEntropyLoss. "Ignored" is a better term here than "masked": the latter implies that the model "does not see" that input, or that the original token was replaced by a special "<masked>" token.
As noted by the other answer and comment: you do target_ids[:, :-seq_len] = -100, which in this case is target_ids[:, :-10] = -100, i.e. it assigns to ALL BUT the LAST 10 elements. Since target_ids has length 10, that slice is empty, nothing is changed, and you get the same result.
:-seq_len is the correct slice to use, but you need to think ahead about what you actually want to ignore, hence my introduction.
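You can verify the slicing behaviour directly; a minimal sketch with seq_len = 10 as in the question:

import torch

target_ids = torch.arange(10).unsqueeze(0)  # shape (1, 10), stand-in for the real token IDs
seq_len = target_ids.size(1)

print(target_ids[:, :-seq_len].shape)       # torch.Size([1, 0]) -> empty slice
target_ids[:, :-seq_len] = -100             # assigns to nothing
print(target_ids)                           # unchanged

# To actually ignore, say, the first 5 tokens you would keep only the last trg_len targets:
trg_len = 5
target_ids[:, :-trg_len] = -100
print(target_ids)                           # the first 5 entries are now -100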
In your linked post, the interesting part happens in the second iteration! In the first iteration nothing is ignored.
Roughly what happens in the first iteration is:
max_length = 1024
stride = 512
begin_loc = 0
end_loc = 1024
input_ids = tokens[0 : 1024]
target_ids = input_ids.clone()
target_ids[:-1024] = -100  # <--- nothing is changed here!
# this is equivalent to target_ids[:0] = -100 or target_ids[:-len(target_ids)] = -100
assert torch.equal(target_ids, input_ids)  # only holds for the first iteration
trg_len = 1024  # NOTE: 1024, not 512 (trg_len = end_loc - prev_end_loc, with prev_end_loc = 0)
prev_end_loc = 1024
The loss is calculated over ALL 1024 tokens.
In the second iteration:
begin_loc = 512
end_loc = 1536
trg_len = end_loc - prev_end_loc  # = 1536 - 1024 = 512, NOTE: here it is 512
input_ids = tokens[512 : 1536]    # NOTE: tokens 512-1023 we have SEEN already
target_ids = tokens[512 : 1536].clone()  # tensor of length 1024
target_ids[:-512] = -100          # the already SEEN tokens are ignored
The loss is calculated only over the last 512 of the 1024 tokens.
So only from the second iteration onward is target_ids a tensor whose first half is filled with -100 while the second half keeps the input tokens, i.e. its structure is:
[-100, ..., -100, <1024>, ..., <1535>]
where <1024> denotes the unmodified input token at that position, over which the loss is calculated.
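Putting this together, the sliding-window loop from the linked article looks roughly like this (a sketch with the article's variable names; device placement is omitted, and encodings is assumed to hold a long text rather than the single sentence from the question):

import torch
from tqdm import tqdm

max_length = model.config.n_positions  # 1024 for GPT-2
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc        # how many NEW tokens this window contributes
    input_ids = encodings.input_ids[:, begin_loc:end_loc]
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100         # ignore tokens already scored in a previous window

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        nlls.append(outputs.loss)           # mean NLL over the non-ignored targets of this window

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())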
The model also lets you calculate the loss yourself: if you pass no labels, it returns only the logits and computes no loss. If you then set the reduction of the loss to 'none' you get the loss for each individual token.
from torch.nn import CrossEntropyLoss

# Note: this is based on the loss computation inside GPT2LMHeadModel
outputs = model(encodings.input_ids, labels=None)
logits = outputs.logits
labels = target_ids.to(logits.device)
# Shift so that positions match: the logits at position i predict the token at position i+1
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens and compute the mean loss
loss_fct = CrossEntropyLoss(reduction='mean')
loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
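Continuing from the variables above, a short sketch of the reduction='none' variant; ignored positions (if any) come out as 0, and averaging the valid ones reproduces the single mean loss from the question:

per_token_fct = CrossEntropyLoss(reduction='none')
per_token_loss = per_token_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
print(per_token_loss)                       # one loss value per predicted token
valid = shift_labels.view(-1) != -100       # positions that are not ignored
print(per_token_loss[valid].mean().item())  # ~4.3523 for the sentence in the question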