I load a float32 Hugging Face Transformers model, cast it to float16, and save it. How can I load it back as float16?
Example:
# pip install transformers
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Load model
model_path = 'huawei-noah/TinyBERT_General_4L_312D'
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Convert the model to FP16
model.half()
# Check model dtype
def print_model_layer_dtype(model):
    print('\nModel dtypes:')
    for name, param in model.named_parameters():
        print(f"Parameter: {name}, Data type: {param.dtype}")
print_model_layer_dtype(model)
save_directory = 'temp_model_SE'
model.save_pretrained(save_directory)
model2 = AutoModelForTokenClassification.from_pretrained(save_directory, local_files_only=True)
print('\n\n##################')
print(model2)
print_model_layer_dtype(model2)
In this example, model2 loads as a float32 model (as shown by print_model_layer_dtype(model2)), even though it was saved as float16 (as shown in config.json). What is the proper way to load it as float16?
Tested with transformers==4.36.2 and Python 3.11.7 on Windows 10.
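Use torch_dtype='auto' in from_pretrained(). Without it, from_pretrained() loads the weights as float32 even if they were saved as float16; with 'auto', it uses the torch_dtype entry in config.json and, if that entry is missing, the dtype of the first floating-point weight in the checkpoint. Example: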
model2 = AutoModelForTokenClassification.from_pretrained(
    save_directory,
    local_files_only=True,
    torch_dtype='auto',
)
Full example:
# pip install transformers
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
# Load model
model_path = 'huawei-noah/TinyBERT_General_4L_312D'
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Convert the model to FP16
model.half()
# Check model dtype
def print_model_layer_dtype(model):
    print('\nModel dtypes:')
    for name, param in model.named_parameters():
        print(f"Parameter: {name}, Data type: {param.dtype}")
print_model_layer_dtype(model)
save_directory = 'temp_model_SE'
model.save_pretrained(save_directory)
model2 = AutoModelForTokenClassification.from_pretrained(save_directory, local_files_only=True, torch_dtype='auto')
print('\n\n##################')
print(model2)
print_model_layer_dtype(model2)
This loads model2 as torch.float16.
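Since torch_dtype='auto' relies on what was recorded when the model was saved, it can help to check that entry directly. Below is a minimal sketch using only the standard library; the helper name recorded_dtype is hypothetical, and the fallback to a "dtype" key is an assumption for releases newer than the 4.36.2 tested here.

```python
import json
from pathlib import Path


def recorded_dtype(save_directory):
    """Return the dtype string stored in config.json, or None if absent."""
    config_path = Path(save_directory) / "config.json"
    with config_path.open(encoding="utf-8") as f:
        config = json.load(f)
    # "torch_dtype" is the key written by save_pretrained() in transformers 4.36.x.
    # The "dtype" fallback is an assumption for newer releases, not verified here.
    return config.get("torch_dtype", config.get("dtype"))
```

Calling recorded_dtype('temp_model_SE') after the save step above should return the string the question refers to in config.json. If it returns None, 'auto' has no config entry to rely on and falls back to inspecting the checkpoint weights.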