Tags: machine-learning, pytorch, huggingface-transformers, huggingface

Is it possible to save the training/validation loss in a list during training in HuggingFace?


I'm currently training my model using the HuggingFace Trainer class:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="codeparrot-ds",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    eval_steps=5_000,
    logging_steps=5_000,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=5_000,
    fp16=True,
    push_to_hub=True,
)

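# model, tokenizer, data_collator and tokenized_datasets are defined earlier (not shown)
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
)

trainer.train()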

This prints the loss during training, but I can't figure out how to save it so that I can plot it later. Note that I need both the training and validation losses during training.


Solution

  • One way to proceed: after training, you can access the training and evaluation losses through the trainer.state.log_history object. Its entries may also contain other metrics, such as accuracy and f1; these can be ignored here, as they come from the specific compute_metrics function passed as a parameter to the trainer instance.


    It is a list of dicts holding the values logged at each logging or evaluation step; training entries contain the key 'loss' and evaluation entries the key 'eval_loss'. The values can be retrieved as follows (the validation losses analogously, via 'eval_loss'):

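    train_loss = []
    eval_loss = []
    for elem in trainer.state.log_history:
        if 'loss' in elem:  # training log entry
            train_loss.append(elem['loss'])
        elif 'eval_loss' in elem:  # evaluation log entry
            eval_loss.append(elem['eval_loss'])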

    The loss is computed by the compute_loss() method of the Trainer class, which can be overridden for custom behaviour, as described at https://huggingface.co/docs/transformers/v4.27.1/en/main_classes/trainer#trainer.
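Building on the loop above, here is a minimal sketch that keeps the step number next to each loss, so the training and validation curves can later be plotted on a shared x-axis, and writes both series to a JSON file. The helper names split_losses and save_losses, and the file name used below, are hypothetical. The sketch assumes the entry layout described above, with each entry also carrying a 'step' key, which the Trainer adds when it logs.

```python
import json
from typing import Any, Dict, List, Optional, Tuple

# (step, loss) pair; step comes from the "step" key of each log entry
LossPoint = Tuple[Optional[int], float]


def split_losses(log_history: List[Dict[str, Any]]) -> Dict[str, List[LossPoint]]:
    """Split a Trainer log history into training and validation (step, loss) series."""
    train: List[LossPoint] = []
    evaluation: List[LossPoint] = []
    for entry in log_history:
        step = entry.get("step")
        if "loss" in entry:  # training entry, logged every `logging_steps`
            train.append((step, entry["loss"]))
        elif "eval_loss" in entry:  # evaluation entry, logged every `eval_steps`
            evaluation.append((step, entry["eval_loss"]))
        # any other entry (e.g. the final summary with "train_loss") is skipped
    return {"train": train, "eval": evaluation}


def save_losses(log_history: List[Dict[str, Any]], path: str) -> Dict[str, List[LossPoint]]:
    """Write both loss series to a JSON file (tuples become lists) and return them."""
    losses = split_losses(log_history)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(losses, f, indent=2)
    return losses
```

After trainer.train() returns, save_losses(trainer.state.log_history, "losses.json") writes the file. Each series can then be unpacked for plotting with steps, values = zip(*losses["train"]). The summary entry appended at the end of training stores the run's average loss under 'train_loss' rather than 'loss', so the key check leaves it out of the per-step curve. The checkpoint folders the Trainer writes under output_dir also contain a trainer_state.json file whose "log_history" field holds the same list, so the helper can be applied to a loaded copy of that list as well.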