Tags: python, pytorch, huggingface-transformers, huggingface, huggingface-trainer

Huggingface SFT for completion only not working


I have a project where I am trying to fine-tune Llama-2-7b on a dataset for parameter extraction, linked here: <GalaktischeGurke/parameter_extraction_1500_mail_contract_invoice>. The problem is that the context for each response is very long, so training on the full text (context plus response) rather than only on the response results in a big drop in performance. To fix this, I wanted to use SFTTrainer together with DataCollatorForCompletionOnlyLM, which allows fine-tuning on the response only. Before adjusting my own training loop, I wanted to try the examples given here: https://huggingface.co/docs/trl/main/en/sft_trainer. Specifically, I used this code from the page:

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
output_dir = "./results"

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

instruction_template = "### Human:"
response_template = "### Assistant:"
# Mask everything outside the assistant responses so the loss is only computed on them
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=collator,
)

trainer.train() 


import os
output_dir = os.path.join(output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)

The training loop did not crash, but it never seemed to train at all: there was no train/loss curve on wandb, and the saved model did not seem to have changed.
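
For what it's worth, my understanding is that the collator is supposed to set the labels of everything up to and including "### Assistant:" to -100, so that only the response tokens contribute to the loss. A quick sanity check on a single example (just calling the collator by hand) looks like this:

# The collator should mask the prompt: labels are -100 up to and including the
# response template, and only the assistant reply keeps its token ids.
sample = "### Human: What is the capital of France?### Assistant: Paris"
batch = collator([tokenizer(sample)])
print(batch["labels"])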

These are the things I tried:

  • Using the other code with the preformat function
  • Setting packing=False on the trainer
  • Implementing it with my own loop, which yielded the same results
  • Trying to find documentation on the collator; however, it is not in the official docs at https://huggingface.co/docs/transformers/main_classes/data_collator

Does anyone know what the issue is here?


Solution

  • I have a similar issue. I think you're forgetting to add a formatting_func function. Also, setting dataset_text_field overrides the use of the collator by default, so try without that argument.

    Here's how I call it. It runs and stores things to wandb, but my problem is my loss is always NaN. Lemme know if you found the issue!

    trainer = SFTTrainer(
        model,
        train_dataset=vanilla_data_set,
        eval_dataset=vanilla_data_set,
        args=training_args,
        # dataset_text_field="gpt-4",  # left out so the collator is actually used
        # torch_dtype=torch.bfloat16,
        peft_config=peft_config,
        max_seq_length=512,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
    )
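
    For context, training_args, peft_config and formatting_prompts_func are nothing exotic, and collator is the same DataCollatorForCompletionOnlyLM as in your snippet. Roughly something like this (the "question"/"answer" column names are just placeholders for whatever your dataset actually uses):

    from peft import LoraConfig
    from transformers import TrainingArguments

    training_args = TrainingArguments(output_dir="./results", per_device_train_batch_size=4)
    peft_config = LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32, lora_dropout=0.05)

    def formatting_prompts_func(example):
        # Build one "### Human: ... ### Assistant: ..." string per row so the
        # collator can find the response template and mask everything before it.
        output_texts = []
        for i in range(len(example["question"])):
            output_texts.append(
                f"### Human: {example['question'][i]}\n### Assistant: {example['answer'][i]}"
            )
        return output_texts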