Ways to print the value of labels in 'compute_loss' of the transformers Trainer to the terminal during training

I would like to print the labels to the terminal by adding code inside 'compute_loss' of the transformers Trainer.

I tried adding both print and …, but neither worked.

def compute_loss(self, model, inputs, return_outputs=False):
    """
    How the loss is computed by Trainer. By default, all models return the loss in the first element.

    Subclass and override for custom behavior.
    """
    if self.label_smoother is not None and "labels" in inputs:
        labels = inputs.pop("labels")
    else:
        labels = None
    # here~! (labels is None unless label smoothing is enabled; otherwise they stay in inputs["labels"])
    print(f"{labels}")
    outputs = model(**inputs)
    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    if labels is not None:
        if is_peft_available() and isinstance(model, PeftModel):
            model_name = unwrap_model(model.base_model)._get_name()
        else:
            model_name = unwrap_model(model)._get_name()
        if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)
    else:
        if isinstance(outputs, dict) and "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

    return (loss, outputs) if return_outputs else loss
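Even once a print inside compute_loss runs, the output can be hard to read: under DeepSpeed every process prints its own copy (the log below shows two processes, which is why each warning appears twice), and causal-LM label tensors usually contain many -100 entries, the ignore index used by transformers' LabelSmoother and the default of PyTorch's CrossEntropyLoss. The following is a minimal, framework-free sketch of a helper that could be called from an overridden compute_loss. The names print_labels, format_labels, visible_label_ids and is_main_process are hypothetical, and it assumes the launcher exports the RANK environment variable (treated as 0 when absent).

```python
import os
from typing import Iterable, List, Sequence

# Assumption: positions that do not contribute to the loss are marked with -100,
# the ignore index used by transformers' LabelSmoother and torch's CrossEntropyLoss.
IGNORE_INDEX = -100


def is_main_process() -> bool:
    """True on the rank-0 process; assumes the launcher exports RANK (0 if unset)."""
    return int(os.environ.get("RANK", "0")) == 0


def visible_label_ids(row: Iterable[int], ignore_index: int = IGNORE_INDEX) -> List[int]:
    """Label ids of one example with the ignored positions dropped."""
    return [int(t) for t in row if int(t) != ignore_index]


def format_labels(batch: Sequence[Iterable[int]], ignore_index: int = IGNORE_INDEX) -> str:
    """One line per example: the supervised ids and how many there are."""
    lines = []
    for i, row in enumerate(batch):
        ids = visible_label_ids(row, ignore_index)
        lines.append(f"example {i}: {ids} ({len(ids)} supervised)")
    return "\n".join(lines)


def print_labels(batch: Sequence[Iterable[int]], ignore_index: int = IGNORE_INDEX) -> None:
    """Print the formatted labels on rank 0 only, flushing stdout immediately."""
    if is_main_process():
        print(format_labels(batch, ignore_index), flush=True)


if __name__ == "__main__":
    # Stand-in for labels.tolist() taken from a real batch.
    print_labels([[-100, -100, 42, 7], [-100, 13, -100, -100]])
```

Inside an overridden compute_loss this might be used as labels = inputs.get("labels") followed by print_labels(labels.tolist()) when labels is not None. If a tokenizer is available, tokenizer.decode(visible_label_ids(row)) would turn one row back into text.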

The terminal shows:

(qwen) seamoon@gpu07:~/code/Qwen$ bash finetune/ 
[2024-01-06 17:55:53,192] [WARNING] 
[2024-01-06 17:55:53,192] [WARNING] *****************************************
[2024-01-06 17:55:53,192] [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-01-06 17:55:53,192] [WARNING] *****************************************
[2024-01-06 17:55:54,784] [INFO] [] Setting ds_accelerator to cuda (auto detect)
[2024-01-06 17:55:54,853] [INFO] [] Setting ds_accelerator to cuda (auto detect)
The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...
Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency
Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency
Warning: import flash_attn fail, please install FlashAttention to get higher efficiency
The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...
Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency
Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency
Warning: import flash_attn fail, please install FlashAttention to get higher efficiency
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.41it/s]
Loading data...
Formatting inputs...Skip in lazy mode
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.10it/s]
  0%|                                                                                                                                                                   | 0/20 [00:00<?, ?it/s]/home/seamoon/anaconda3/envs/qwen/lib/python3.10/site-packages/torch/utils/ UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
/home/seamoon/anaconda3/envs/qwen/lib/python3.10/site-packages/torch/utils/ UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
{'loss': 1.1746, 'learning_rate': 1e-05, 'epoch': 0.24}                                                                                                                                        
{'loss': 1.2259, 'learning_rate': 9.931806517013612e-06, 'epoch': 0.48}                                                                                                                        
{'loss': 0.1804, 'learning_rate': 9.729086208503174e-06, 'epoch': 0.73}                                                                                                                        
{'loss': 0.1697, 'learning_rate': 9.397368756032445e-06, 'epoch': 0.97}                                                                                                                        
{'loss': 0.1526, 'learning_rate': 8.94570254698197e-06, 'epoch': 1.21}                                                                                                                         
{'loss': 0.1521, 'learning_rate': 8.386407858128707e-06, 'epoch': 1.45}                                                                                                                        
{'loss': 0.1554, 'learning_rate': 7.734740790612137e-06, 'epoch': 1.7}                                                                                                                         
{'loss': 0.1513, 'learning_rate': 7.008477123264849e-06, 'epoch': 1.94}                                                                                                                        
{'loss': 0.1523, 'learning_rate': 6.227427435703997e-06, 'epoch': 2.18}                                                                                                                        
{'loss': 0.1326, 'learning_rate': 5.412896727361663e-06, 'epoch': 2.42}                                                                                                                        
{'loss': 0.1484, 'learning_rate': 4.587103272638339e-06, 'epoch': 2.67}                                                                                                                        
{'loss': 0.1335, 'learning_rate': 3.7725725642960047e-06, 'epoch': 2.91}                                                                                                                       
{'loss': 0.1443, 'learning_rate': 2.991522876735154e-06, 'epoch': 3.15}                                                                                                                        
{'loss': 0.133, 'learning_rate': 2.265259209387867e-06, 'epoch': 3.39}                                                                                                                         
{'loss': 0.1391, 'learning_rate': 1.6135921418712959e-06, 'epoch': 3.64}                                                                                                                       
{'loss': 0.1299, 'learning_rate': 1.0542974530180327e-06, 'epoch': 3.88}                                                                                                                       
{'loss': 0.1414, 'learning_rate': 6.026312439675553e-07, 'epoch': 4.12}                                                                                                                        
{'loss': 0.1366, 'learning_rate': 2.7091379149682683e-07, 'epoch': 4.36}                                                                                                                       
{'loss': 0.1349, 'learning_rate': 6.819348298638839e-08, 'epoch': 4.61}                                                                                                                        
{'loss': 0.1322, 'learning_rate': 0.0, 'epoch': 4.85}                                                                                                                                          
{'train_runtime': 99.7731, 'train_samples_per_second': 6.615, 'train_steps_per_second': 0.2, 'train_loss': 0.2510054260492325, 'epoch': 4.85}                                                  
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [01:39<00:00,  4.98s/it]
path: /home/seamoon/anaconda3/lib/python3.11/site-packages/transformers/
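The torch warnings in this log come from /home/seamoon/anaconda3/envs/qwen/lib/python3.10/site-packages, while the path noted at the end points into /home/seamoon/anaconda3/lib/python3.11/site-packages/transformers/, which belongs to a different Python installation. If the edited file is not the copy the training run actually imports, no added print or logging call will ever execute. A quick way to check, sketched below with a hypothetical helper named module_origin, is to ask the interpreter where it would load transformers from. It uses only the standard library (importlib.util.find_spec) and has to be run with the same python that the finetune script uses, for example inside the activated qwen environment.

```python
import importlib.util
import sys
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """File a top-level module would be loaded from, or None if it is not installed."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Which interpreter is running, and which transformers it would import.
    print("python:", sys.executable)
    print("transformers:", module_origin("transformers"))
```

If the printed transformers path does not match the file that was edited, the edit has to go into the copy under the environment that runs training, or, less invasively, into a Trainer subclass as in the answer below.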


My question is: is there any way to print the value of labels?

Any help would be appreciated. :)


  • You can subclass Trainer and override compute_loss:

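    from transformers import Trainer

    class CustomTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # Read the labels without removing them from inputs
            labels = inputs.get("labels")
            print(labels)  # or use the logging module instead of print
            # Delegate to the original compute_loss; **kwargs forwards any extra
            # arguments that newer transformers versions pass (e.g. num_items_in_batch)
            return super().compute_loss(model, inputs, return_outputs=return_outputs, **kwargs)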

    Then create the trainer with:

        # Create the trainer
        trainer = CustomTrainer(
            model=model, tokenizer=tokenizer, args=training_args, **data_module