python, huggingface, wandb, accelerate

Get accelerate package to log test results with huggingface Trainer


I am fine-tuning a T5 model on a specific dataset and my code looks like this:

accelerator = Accelerator(log_with='wandb')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

accelerator.init_trackers(
    project_name='myProject',
    config={
        # My configs
    }
)
# Then I do some preparations towards the fine-tuning

trainer_arguments = transformers.Seq2SeqTrainingArguments(
    # Here I pass many arguments
)
trainer = transformers.Seq2SeqTrainer(
    # Here I pass the arguments alongside other needed arguments
)

# THEN FINALLY I TRAIN, EVALUATE AND TEST LIKE SO:

trainer.train()
trainer.evaluate(
    # my evaluation parameters
)
trainer.predict(
    # my test arguments
)

Now to my main issue: when I check the wandb site for my project, I only see logging for the trainer.train() phase, but not for the trainer.evaluate() or trainer.predict() phases.

I've scoured the web trying to find a solution but could not find any.

How do I get wandb/accelerate to log all of my phases?
Thanks!

For the full code, you can see it here: https://github.com/zbambergerNLP/principled-pre-training/blob/master/fine_tune_t5.py


Solution

  • Unfortunately, evaluation and prediction metrics are not logged to wandb automatically the way training metrics are. But there are ways to push them to wandb.

    Solution 01

    You can log evaluation and prediction metrics manually after each phase:

    import wandb

    # After evaluation
    eval_metrics = trainer.evaluate()
    wandb.log({"evaluation": eval_metrics})
    
    # After prediction
    predictions = trainer.predict(test_dataset)
    wandb.log({"predictions": predictions.metrics})
    

    Solution 02

    You can also add a callback that logs your metrics automatically after evaluation and prediction:

    import wandb
    from transformers import TrainerCallback
    
    class WandbLoggingCallback(TrainerCallback):
        def on_evaluate(self, args, state, control, **kwargs):
            # Log evaluation metrics
            metrics = kwargs.get("metrics", {})
            wandb.log({"eval": metrics})
    
        def on_predict(self, args, state, control, **kwargs):
            # Log prediction metrics
            metrics = kwargs.get("metrics", {})
            wandb.log({"predictions": metrics})
    
    # Use this callback in your trainer
    trainer = transformers.Seq2SeqTrainer(
        # Your arguments
        callbacks=[WandbLoggingCallback()],
        # Other needed arguments
    )
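
    With the callback registered, metrics are pushed to wandb every time the corresponding phase runs. A short usage sketch follows; validation_dataset and test_dataset are placeholder names standing in for your own datasets:

    # Each call triggers the matching callback hook, which logs to wandb
    trainer.train()
    trainer.evaluate(eval_dataset=validation_dataset)  # fires on_evaluate
    trainer.predict(test_dataset)                      # fires on_predict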
    

    Here's a simple Colab notebook that I borrowed from HuggingFace and modified with ways to push evaluation and prediction metrics after training. You will find both the manual and the automatic approaches there.