Tags: python, nan, huggingface-transformers, pytorch-lightning, loss

Loss becomes NaN after attention_mask is passed to the model while fine-tuning Gemma 2


I was trying to fine-tune the Gemma 2 2B model on my own dataset for a sequence classification task. While testing the model, I found that as soon as I pass attention_mask to the model, the loss becomes NaN.

Here is my code:

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoTokenizer, Gemma2ForSequenceClassification

# Load the base model in bfloat16
base_model = Gemma2ForSequenceClassification.from_pretrained(
    "gemma2b",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# LoRA adapters on all attention and MLP projections
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
tokenizer = AutoTokenizer.from_pretrained("gemma2b")

label = torch.tensor([0]).to("cuda")

# Pad a single short sentence to 10 tokens, so attention_mask contains zeros
raw_t = tokenizer(
    ["I like it too"],
    return_tensors="pt",
    padding="max_length",
    max_length=10,
).to("cuda")

# Forward pass with the attention mask: this is where the NaN shows up
print(model(input_ids=raw_t.input_ids,
            attention_mask=raw_t.attention_mask,
            labels=label))

And here is the output:

SequenceClassifierOutputWithPast(loss=tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>), logits=tensor([[nan, nan]], device='cuda:0', dtype=torch.bfloat16,grad_fn=<IndexBackward0>), past_key_values=None, hidden_states=None, attentions=None)

If I don't pass attention_mask, the loss looks fine.

I also noticed that if I don't pad the input to max_length (so attention_mask is all 1s), the problem does not occur.

And if I change the precision to float16, the loss looks normal too.

Could anyone help me solve this problem?


Solution

  • This is a problem with the default attention implementation. Switching to flash attention may resolve it:

    https://github.com/huggingface/transformers/issues/32390
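
As a minimal sketch of the suggestion above: the attention backend can be selected at load time through the attn_implementation argument of from_pretrained. The helper names build_load_kwargs and load_model are hypothetical and only here for illustration; "gemma2b" is the path used in the question. Note that "flash_attention_2" needs the separate flash-attn package and a supported GPU. Whether this removes the NaN in a particular setup is not guaranteed, so check the linked issue for details.

```python
# Values that from_pretrained accepts for attn_implementation
ALLOWED_ATTN = {"eager", "sdpa", "flash_attention_2"}


def build_load_kwargs(attn_implementation="flash_attention_2",
                      dtype_name="bfloat16"):
    """Collect loading options (hypothetical helper, not a library API)."""
    if attn_implementation not in ALLOWED_ATTN:
        raise ValueError(
            f"unknown attn_implementation: {attn_implementation!r}")
    return {
        "device_map": "auto",
        "dtype_name": dtype_name,  # resolved to a torch dtype in load_model
        "attn_implementation": attn_implementation,
    }


def load_model(model_path="gemma2b", **options):
    """Load the classifier with an explicitly chosen attention backend."""
    # Imported lazily so the helper above can be used without a GPU stack
    import torch
    from transformers import Gemma2ForSequenceClassification

    kwargs = build_load_kwargs(**options)
    kwargs["torch_dtype"] = getattr(torch, kwargs.pop("dtype_name"))
    return Gemma2ForSequenceClassification.from_pretrained(
        model_path, **kwargs)
```

Calling load_model() would replace the from_pretrained call in the question; the LoRA setup and the forward pass stay the same. If flash-attn cannot be installed, load_model(attn_implementation="eager") is another accepted value to try, although the answer above does not confirm that it avoids the NaN.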