python · huggingface-transformers · gpt-2

Transformers cross-entropy loss masked label issue


I'm trying to compute the cross-entropy loss for a text using GPT-2, following the idea from this article: https://huggingface.co/docs/transformers/perplexity

from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch

from tqdm import tqdm


model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id)
                                                                                          
tokenizer = GPT2TokenizerFast.from_pretrained(model_id, cache_dir='.')

encodings = tokenizer("She felt his demeanor was sweet and endearing.", return_tensors="pt")

max_length = model.config.n_positions
seq_len = encodings.input_ids.size(1)

target_ids = encodings.input_ids.clone()
# target_ids[:, :-seq_len] = -100   # <-- the commented line in question

with torch.no_grad():
    outputs = model(encodings.input_ids, labels=target_ids)
    print(outputs.loss.item())

Whether that line is commented out or not makes no difference: I always get 4.352320194244385 as the printed output. According to the documentation https://huggingface.co/docs/transformers/v4.35.2/en/model_doc/gpt2#transformers.GPT2LMHeadModel :

labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. Indices are selected in [-100, 0, ..., config.vocab_size - 1]. All labels set to -100 are ignored (masked), the loss is only computed for labels in [0, ..., config.vocab_size - 1]

Why do I get the same result?


Solution

  • Decoder-Only models, their inputs and targets

    Let's first start with how decoder-only models roughly work in huggingface. The main ingredients are the input_ids, the label_ids and the attention_mask. The latter is not relevant here.

    Let's say we have the input "The answer is:" and we want the model to learn to answer with "42". Translated to tokens, "The answer is: 42" is [464, 3280, 318, 25, 5433] (where ":" is 25 and " 42" is 5433).

    For the learning objective we only want to learn "42". The targets are therefore label_ids = [-100, -100, -100, -100, 5433]. This way the model will not learn that "The answer" is followed by "is:". (A small sketch of this masking follows right after the sidenotes.)


    Sidenotes:

    Decoder-only models expect input and output to have the same shape.
    In contrast, encoder-decoder models can have "The answer is:" as input and "42" as output.


    -100 is the ignore_index of the standard torch.nn.CrossEntropyLoss. "Ignored" is a better term here than "masked"; the latter implies that the model "does not see" the input in question, or that the original token got replaced by a special "<masked>" token.
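
    To make this concrete, here is a minimal sketch of the "The answer is: 42" example above (the token IDs are the ones quoted above and come from the GPT-2 tokenizer; the logits are random stand-ins for a real model output):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 50257                                        # GPT-2 vocabulary size
    input_ids = torch.tensor([[464, 3280, 318, 25, 5433]])    # "The answer is: 42"

    # Only learn to predict " 42": everything else is set to the ignore_index -100
    label_ids = input_ids.clone()
    label_ids[:, :-1] = -100
    print(label_ids)   # tensor([[-100, -100, -100, -100, 5433]])

    # CrossEntropyLoss skips the -100 positions (ignore_index=-100 by default),
    # so only the last position contributes to the loss. (Inside the model the
    # labels are additionally shifted by one, which is omitted here.)
    logits = torch.randn(1, 5, vocab_size)                    # fake logits, one row per token
    loss = CrossEntropyLoss()(logits.view(-1, vocab_size), label_ids.view(-1))
    print(loss)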



    Why do you get the same result & how to proceed further

    As noted by the other answer and comment: you do target_ids[:, :-seq_len] = -100, which in this case is target_ids[:, :-10], i.e. ALL BUT the 10 LAST elements. Since target_ids has length 10, nothing is changed and you get the same result.

    :-seq_len is the correct slice to use here, but you need to think a bit further ahead about what you actually want, hence my introduction above.
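
    A quick standalone demo of why that slice selects nothing (a hypothetical toy tensor, not part of your code):

    import torch

    target_ids = torch.arange(10).unsqueeze(0)   # shape (1, 10), stands in for your 10 tokens
    seq_len = target_ids.size(1)                 # 10

    # [:, :-10] on a length-10 sequence is an empty slice, so nothing is assigned
    target_ids[:, :-seq_len] = -100
    print(target_ids)   # tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) -- unchanged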

    Going further: iterating over a dataset with a context

    In the guide you linked, the interesting part happens in the second iteration! In the first one nothing is ignored.

    Roughly what happens is:

    -- First iteration --

    max_length = 1024 
    stride = 512
    
    end_loc = 1024
    input_ids = tokens[0 : 1024]
    target_ids = input_ids.clone()
    target_ids[:-1024] = -100 # <--- nothing is changed here!
    # this is equivalent to target_ids[:0] = -100 or target_ids[:-len(target_ids)] = -100
    
    assert torch.equal(target_ids, input_ids)  # only for the first iteration
    
    trg_len = 1024  # NOTE 1024 not 512
    prev_end_loc = 1024
    

    The loss is calculated over ALL 1024 tokens.

    -- Second+ iteration --

    begin_loc = 512
    end_loc = 1536
    trg_len = end_loc - prev_end_loc  # 1536 - 1024 = 512, NOTE here it is 512
    
    input_ids = tokens[512 : 1536]    # NOTE: tokens 512-1024 have been SEEN already
    target_ids = input_ids.clone()    # tensor of length 1024
    target_ids[:-512] = -100          # the already SEEN tokens are ignored
    

    The loss is calculated only over the latter 512 tokens of the 1024.

    So only from the second iteration onward is target_ids a tensor whose first half is filled with -100 while the second half holds the input labels, i.e. its structure is [-100, ..., -100, <1024>, ..., <1536>], where <1024> denotes the unmodified input token at that position, which the loss is calculated over.
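
    For reference, the whole sliding-window loop from the linked perplexity guide looks roughly like this (a lightly adapted sketch; it reuses model and encodings from your snippet and assumes the text is long enough to need several windows):

    import torch
    from tqdm import tqdm

    max_length = model.config.n_positions   # 1024 for GPT-2
    stride = 512
    seq_len = encodings.input_ids.size(1)

    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc     # may differ from stride on the last window
        input_ids = encodings.input_ids[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100      # ignore tokens already scored in a previous window

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            nlls.append(outputs.loss)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())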


    Alternative: calculate the loss yourself

    The model lets you calculate the loss yourself if you pass no labels. If you set the reduction of the loss to 'none' you get a loss value for each individual token.

    # Note: this is based on the original code from the question
    from torch.nn import CrossEntropyLoss
    
    outputs = model(encodings.input_ids, labels=None)
    
    logits = outputs.logits
    labels = target_ids.to(logits.device)
    
    # Shift so that the logits at position i are compared to the label at position i+1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(reduction='mean')
    loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
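
    If you actually want the per-token losses mentioned above, swap the reduction. A small follow-up sketch reusing shift_logits and shift_labels from the block above (the equality noted in the last comment only holds here because no label is set to -100):

    loss_fct_none = CrossEntropyLoss(reduction='none')
    per_token_loss = loss_fct_none(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
    print(per_token_loss)          # one loss value per predicted token (seq_len - 1 values)
    print(per_token_loss.mean())   # matches the 'mean'-reduced loss above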