According to the official documentation, the Trainer class "provides an API for feature-complete training in PyTorch for most standard use cases".
However, when I actually use Trainer in practice, I get the following error message, which seems to suggest that TensorFlow is being used under the hood.
tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
So which one is it? Does the HuggingFace transformers library use PyTorch or TensorFlow for its internal implementation of Trainer? And is it possible to switch to only using PyTorch? I can't seem to find a relevant parameter in TrainingArguments.
Why does my script keep printing out TensorFlow-related errors? Shouldn't Trainer be using PyTorch only?
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
from transformers import TextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer
from transformers import TrainingArguments
import torch
# Load the GPT-2 tokenizer and LM head model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
lmhead_model = GPT2LMHeadModel.from_pretrained('gpt2')
# Load the training dataset and split it into blocks of block_size tokens
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='./datasets/tinyshakespeare.txt',
    block_size=64
)
# Create a data collator for preprocessing batches
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
# Defining the training arguments
training_args = TrainingArguments(
    output_dir='./models/tinyshakespeare',  # output directory for checkpoints
    overwrite_output_dir=True,              # overwrite any existing content
    per_device_train_batch_size=4,          # sample batch size for training
    dataloader_num_workers=1,               # number of workers for dataloader
    max_steps=100,                          # maximum number of training steps
    save_steps=50,                          # after # steps checkpoints are saved
    save_total_limit=5,                     # maximum number of checkpoints to save
    prediction_loss_only=True,              # only compute loss during prediction
    learning_rate=3e-4,                     # learning rate
    fp16=False,                             # use 16-bit (mixed) precision
    optim='adamw_torch',                    # define the optimizer for training
    lr_scheduler_type='linear',             # define the learning rate scheduler
    logging_steps=5,                        # after # steps logs are printed
    report_to='none',                       # report to wandb, tensorboard, etc.
)
if __name__ == '__main__':
    torch.multiprocessing.freeze_support()
    trainer = Trainer(
        model=lmhead_model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    trainer.train()
It depends on how the model was trained and how you load it. Most popular models on transformers support both PyTorch and TensorFlow (and sometimes also JAX).
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import TFAutoModelForSeq2SeqLM
model_name = "google/flan-t5-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# This would work if the model's backend is PyTorch.
print(type(next(model.parameters())))
tf_model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
# The `model.parameters()` would not work for Tensorflow,
# instead you can try `.summary()`
tf_model.summary()
[out]:
<class 'torch.nn.parameter.Parameter'>
Model: "tft5_for_conditional_generation"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
shared (Embedding) multiple 32899072
encoder (TFT5MainLayer) multiple 341231104
decoder (TFT5MainLayer) multiple 441918976
lm_head (Dense) multiple 32899072
=================================================================
Total params: 783,150,080
Trainable params: 783,150,080
Non-trainable params: 0
_________________________________________________________________
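Maybe something like:
def which_backend(model):
    try:
        # Only PyTorch modules expose .parameters()
        model.parameters()
        return 'torch'
    except Exception:
        try:
            # Keras (TensorFlow) models expose .summary()
            model.summary()
            return 'tensorflow'
        except Exception:
            return 'I have no idea... Maybe JAX?'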
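A variant of the same idea that avoids calling any methods on the model is to look at where the model's base classes are defined. This is only a rough sketch: the function name and the list of module roots are my own assumptions, and it does not try to recognise JAX/Flax models at all.

```python
def which_backend_by_module(model) -> str:
    """Guess the backend from the top-level packages in the model's class hierarchy.

    Hypothetical helper, not part of transformers.
    """
    # Collect the top-level package of every class the model inherits from,
    # e.g. 'torch' for torch.nn.modules.module.Module.
    roots = {cls.__module__.split(".")[0] for cls in type(model).__mro__}
    if "torch" in roots:
        return "torch"
    # Keras models may live under any of these packages depending on the install
    # (assumed list, may be incomplete).
    if roots & {"tensorflow", "keras", "tf_keras"}:
        return "tensorflow"
    return "unknown"
```

Unlike the try/except version, this never executes model code, so it also works on models that have not been built yet.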
Q: So if I use Trainer, it's PyTorch?
A: Yes, most probably the model has a PyTorch backend, and the training loop (optimizer, loss, etc.) uses PyTorch. But Trainer() isn't the model; it's the wrapper object.
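On the "switch to only using PyTorch" part of the question: as far as I know there is no TrainingArguments parameter for this, but transformers reads the USE_TF and USE_TORCH environment variables when it is imported, and TensorFlow itself reads TF_CPP_MIN_LOG_LEVEL to filter its C++ log output. A minimal sketch, assuming the variables are set before the first import of transformers (the helper name is made up for illustration):

```python
import os


def prefer_pytorch_only() -> None:
    """Set environment variables that steer transformers towards PyTorch.

    Hypothetical helper; it must run before `import transformers`.
    """
    os.environ["USE_TF"] = "0"               # ask transformers not to use TensorFlow
    os.environ["USE_TORCH"] = "1"            # ask transformers to use PyTorch
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # hide TensorFlow INFO and WARNING logs


prefer_pytorch_only()
# Import transformers only after this point, e.g.:
# from transformers import Trainer, TrainingArguments
```

Whether this silences the message entirely depends on what else in the environment imports TensorFlow, so treat it as something to try rather than a guaranteed fix.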
Q: And if I want to use Trainer for TensorFlow backend models, I should use TFTrainer?
A: Not really. In the latest version of transformers, the TFTrainer object is deprecated, see https://github.com/huggingface/transformers/pull/12706
It is recommended that you use Keras' sklearn-style .fit() training if you are using a model with a TensorFlow backend.
Try checking your transformers version. Most probably you are using an outdated version that relies on some deprecated objects, e.g. TextDataset (see How to resolve "only integer tensors of a single element can be converted to an index" error when creating a Dataset to fine-tune GPT2 model?).
In later versions, most probably after pip install "transformers>=4.26.1", Trainer shouldn't trigger TF warnings, and using TFTrainer would raise warnings suggesting that users switch to Keras instead.
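To check the installed version from Python, something like the following should work. It is only a sketch: the helper names are my own, the comparison ignores pre-release suffixes such as .dev0 or rc1, and 4.26.1 is just the version guessed above.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def version_tuple(text: str) -> tuple:
    """Turn '4.26.1' into (4, 26, 1), keeping only the leading digits of each part."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric part, e.g. 'dev0'
        parts.append(int(match.group()))
    return tuple(parts)


def is_at_least(package: str, minimum: str) -> bool:
    """True if the package is installed and its version is >= minimum."""
    found = installed_version(package)
    return found is not None and version_tuple(found) >= version_tuple(minimum)


print(installed_version("transformers"), is_at_least("transformers", "4.26.1"))
```

If this prints None or False, upgrading transformers is probably the first thing to try.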