python, tensorflow, pytorch, huggingface-transformers

Is the default `Trainer` class in HuggingFace transformers using PyTorch or TensorFlow under the hood?


Question

According to the official documentation, the Trainer class "provides an API for feature-complete training in PyTorch for most standard use cases".

However, when I actually run Trainer in practice, I get the following message, which seems to suggest that TensorFlow is being used under the hood.

tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

So which one is it? Does the HuggingFace transformers library use PyTorch or TensorFlow for its internal implementation of Trainer? And is it possible to switch to using only PyTorch? I can't seem to find a relevant parameter in TrainingArguments.

Why does my script keep printing out TensorFlow related errors? Shouldn't Trainer be using PyTorch only?

Source code

from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
from transformers import TextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer
from transformers import TrainingArguments

import torch

# Load the GPT-2 tokenizer and LM head model
tokenizer    = GPT2Tokenizer.from_pretrained('gpt2')
lmhead_model = GPT2LMHeadModel.from_pretrained('gpt2')

# Load the training dataset and divide blocksize
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='./datasets/tinyshakespeare.txt',
    block_size=64
)

# Create a data collator for preprocessing batches
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Defining the training arguments
training_args = TrainingArguments(
    output_dir='./models/tinyshakespeare', # output directory for checkpoints
    overwrite_output_dir=True,             # overwrite any existing content

    per_device_train_batch_size=4,         # sample batch size for training
    dataloader_num_workers=1,              # number of workers for dataloader
    max_steps=100,                         # maximum number of training steps
    save_steps=50,                         # after # steps checkpoints are saved
    save_total_limit=5,                    # maximum number of checkpoints to save

    prediction_loss_only=True,             # only compute loss during prediction
    learning_rate=3e-4,                    # learning rate
    fp16=False,                            # use 16-bit (mixed) precision

    optim='adamw_torch',                   # define the optimizer for training
    lr_scheduler_type='linear',            # define the learning rate scheduler

    logging_steps=5,                       # after # steps logs are printed
    report_to='none',                      # report to wandb, tensorboard, etc.
)

if __name__ == '__main__':
    torch.multiprocessing.freeze_support()

    trainer = Trainer(
        model=lmhead_model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()

Solution

  • It depends on how the model is trained and how you load it. Most popular models on transformers support both PyTorch and TensorFlow (and sometimes also JAX).

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    from transformers import TFAutoModelForSeq2SeqLM
    
    model_name = "google/flan-t5-large"
    
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    
    # This would work if the model's backend is PyTorch.
    print(type(next(model.parameters())))
    
    
    tf_model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
    
    # `model.parameters()` does not exist for the TensorFlow model;
    # use `.summary()` instead.
    tf_model.summary()
    

    [out]:

    <class 'torch.nn.parameter.Parameter'>
    
    Model: "tft5_for_conditional_generation"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     shared (Embedding)          multiple                  32899072  
                                                                     
     encoder (TFT5MainLayer)     multiple                  341231104 
                                                                     
     decoder (TFT5MainLayer)     multiple                  441918976 
                                                                     
     lm_head (Dense)             multiple                  32899072  
                                                                     
    =================================================================
    Total params: 783,150,080
    Trainable params: 783,150,080
    Non-trainable params: 0
    _________________________________________________________________
    
    

    Maybe something like:

    
    def which_backend(model):
        try:
            # torch.nn.Module exposes .parameters()
            model.parameters()
            return 'torch'
        except AttributeError:
            try:
                # Keras models expose .summary()
                model.summary()
                return 'tensorflow'
            except AttributeError:
                return 'I have no idea... Maybe JAX?'
    
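    Alternatively, if you don't mind importing the frameworks explicitly, an isinstance check is less fragile than catching exceptions. This is just a sketch (the function name backend_by_type is mine); it relies on the fact that PyTorch transformers models subclass torch.nn.Module and TensorFlow ones subclass tf.keras.Model:

    import torch

    def backend_by_type(model):
        # PyTorch transformers models subclass torch.nn.Module
        if isinstance(model, torch.nn.Module):
            return 'torch'
        try:
            import tensorflow as tf
            # TF transformers models subclass tf.keras.Model
            if isinstance(model, tf.keras.Model):
                return 'tensorflow'
        except ImportError:
            pass
        return 'unknown (maybe JAX/Flax?)'

    print(backend_by_type(model))     # 'torch'
    print(backend_by_type(tf_model))  # 'tensorflow'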

    Q: So if I use Trainer, it's PyTorch?

    A: Yes, most probably the model has a PyTorch backend and the training loop (optimizer, loss, etc.) uses PyTorch. But Trainer() isn't the model; it's a wrapper object around it.
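
    A quick way to confirm this, assuming the trainer and lmhead_model objects from the question's script are in scope:

    import torch

    # Trainer stores the model it trains on `trainer.model`;
    # for the script above this is a plain torch.nn.Module.
    print(isinstance(trainer.model, torch.nn.Module))   # True
    print(isinstance(lmhead_model, torch.nn.Module))    # True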

    Q: And if I want to use Trainer for TensorFlow backend models, should I use TFTrainer?

    A: Not really. In recent versions of transformers, the TFTrainer object is deprecated, see https://github.com/huggingface/transformers/pull/12706

    It is recommended that you use Keras' native .fit() training loop if you are using a model with a TensorFlow backend, as sketched below.
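
    A rough sketch of that Keras path for the same GPT-2 setup. Assumptions: TensorFlow is installed, and tokenized_ds is a tokenized datasets.Dataset with an input_ids column (it is a placeholder, not something from the question's script):

    import tensorflow as tf
    from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
    from transformers import DataCollatorForLanguageModeling

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # GPT-2 has no pad token; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token
    tf_model = TFGPT2LMHeadModel.from_pretrained('gpt2')

    # Collate batches as TF tensors instead of the default PyTorch tensors
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False, return_tensors='tf'
    )

    # Convert the (hypothetical) tokenized dataset into a tf.data.Dataset
    tf_train = tf_model.prepare_tf_dataset(
        tokenized_ds, batch_size=4, shuffle=True, collate_fn=data_collator
    )

    # Recent transformers TF models compute the LM loss internally
    # when compile() is called without an explicit loss
    tf_model.compile(optimizer=tf.keras.optimizers.Adam(3e-4))
    tf_model.fit(tf_train, epochs=1)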

    Q: Why does my script keep printing out TensorFlow related errors? Shouldn't Trainer be using PyTorch only?

    A: Check your transformers version; most probably you are using an outdated version that relies on deprecated objects such as TextDataset (see How to resolve "only integer tensors of a single element can be converted to an index" error when creating a Dataset to fine-tune GPT2 model?)

    In later versions, most probably pip install transformers>=4.26.1, Trainer should not trigger TensorFlow warnings, and using TFTrainer would raise a warning suggesting Keras instead.
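
    Also note that if TensorFlow is installed in the same environment, transformers will still detect it at import time, which is usually where the oneDNN message comes from. A sketch of two things that may help; both environment variables must be set before transformers/TensorFlow are imported, and USE_TF/USE_TORCH are switches read by transformers' framework-detection code, so treat their exact behaviour as an assumption about your installed version:

    import os

    # Hide TensorFlow's informational logs (the oneDNN/cpu_feature_guard
    # message is an INFO-level log). Must be set before TF is imported.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # Ask transformers to skip TensorFlow entirely and only use PyTorch.
    # These are read at import time, so set them before `import transformers`.
    os.environ['USE_TF'] = '0'
    os.environ['USE_TORCH'] = '1'

    from transformers import Trainer, TrainingArguments  # PyTorch-only from here on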