Assuming a BERT model is trained in fp16 and then fine-tuned in fp32 for a specific task, would this result in an increase or decrease in accuracy?
I know fp16 can take less memory on the GPU and that training time will be reduced, but I'm unsure about the effect on accuracy.
What you're referring to is called mixed-precision training: the model is trained with low-precision floating-point numbers (e.g., fp16) for most of the computation, while high-precision numbers (e.g., fp32) are kept only where more accuracy is required. The effect on accuracy of fine-tuning a low-precision model at higher precision varies with the specific model and task: in some cases fine-tuning results in an increase in accuracy, while in other cases it may not. Here is a quick and dirty script you can run as a temperature check, to see whether mixed precision is worth trying for your setup:
import time

import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import BertTokenizer, TFBertForSequenceClassification
# Set mixed precision policy: compute in float16, keep variables in float32
tf.keras.mixed_precision.set_global_policy('mixed_float16')
# Load BERT tokenizer and sequence-classification model (MRPC has 2 classes)
bert_model_name = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = TFBertForSequenceClassification.from_pretrained(bert_model_name, num_labels=2)
# Define optimizer; wrap it with dynamic loss scaling so small fp16 gradients don't underflow
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
# Define loss function
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Define metrics
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
# Define batch size
batch_size = 32
# Load training data
train_data = tfds.load('glue/mrpc', split='train', shuffle_files=True)
train_data = train_data.batch(batch_size)
# Fine-tune BERT model
epochs = 5
for epoch in range(epochs):
    start_time = time.time()
    metric.reset_states()
    for batch_idx, data in enumerate(train_data):
        # Tokenize the raw MRPC sentence pairs for this batch
        sentence1 = [s.decode('utf-8') for s in data['sentence1'].numpy()]
        sentence2 = [s.decode('utf-8') for s in data['sentence2'].numpy()]
        encoded = tokenizer(sentence1, sentence2, padding=True, truncation=True,
                            max_length=128, return_tensors='tf')
        # Keep the inputs and labels as integers; the mixed-precision policy
        # handles the float16 casts inside the model
        labels = data['label']
        with tf.GradientTape() as tape:
            outputs = bert_model(encoded, training=True)
            # Compute the loss in float32 for numerical stability
            loss_value = loss(labels, tf.cast(outputs.logits, tf.float32))
            scaled_loss = optimizer.get_scaled_loss(loss_value)
        # Unscale the gradients before applying them
        grads = tape.gradient(scaled_loss, bert_model.trainable_weights)
        grads = optimizer.get_unscaled_gradients(grads)
        optimizer.apply_gradients(zip(grads, bert_model.trainable_weights))
        metric.update_state(labels, outputs.logits)
    epoch_time = time.time() - start_time
    print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss_value:.4f}, '
          f'Accuracy: {metric.result().numpy():.4f}, Time: {epoch_time:.2f}s')
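To actually answer whether fp32 fine-tuning helps on your task, the simplest thing is to rerun the same loop as an fp32 baseline and compare the reported accuracy. A minimal sketch, assuming the same variables and training loop as the script above:

# Switch the global policy back to full fp32 and rebuild the model and optimizer
tf.keras.mixed_precision.set_global_policy('float32')

bert_model_fp32 = TFBertForSequenceClassification.from_pretrained(bert_model_name, num_labels=2)
optimizer_fp32 = tf.keras.optimizers.Adam(learning_rate=3e-5)  # no loss scaling needed in fp32

# Rerun the training loop above with bert_model_fp32 / optimizer_fp32,
# dropping the get_scaled_loss / get_unscaled_gradients steps, and compare
# the printed accuracy and epoch time against the mixed-precision run.

If the two runs end up within noise of each other on your task, the memory and speed savings of mixed precision are essentially free, and you can make the fp16-vs-fp32 call based on your own numbers rather than a general rule.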