Can anyone point to references where one can learn how to perform Quantization Aware Training (QAT) with tf.GradientTape on TensorFlow 2?
I only see this done with the tf.keras API. I do not use tf.keras; I always build customized training with tf.GradientTape, which provides more control over the training process. I now need to quantize a model, but I only see references on how to do it using the tf.keras API.
In the official examples here, they showed QAT training with model.fit. Here is a demonstration of Quantization Aware Training using tf.GradientTape(). But for complete reference, let's do both here.
Base model training. This is directly from the official doc. For more details, please check there.
import os
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
model.fit(
    train_images,
    train_labels,
    epochs=1,
    validation_split=0.1,
)
10ms/step - loss: 0.5411 - accuracy: 0.8507 - val_loss: 0.1142 - val_accuracy: 0.9705
<tensorflow.python.keras.callbacks.History at 0x7f9ee970ab90>
QAT using .fit
Now, performing QAT over the base model.
# ------------- Quantization Aware Training -------------
import tensorflow_model_optimization as tfmot
quantize_model = tfmot.quantization.keras.quantize_model
# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)
# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
q_aware_model.summary()
# Fine-tune with QAT on a small subset of the training data.
train_images_subset = train_images[0:1000]
train_labels_subset = train_labels[0:1000]
q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
356ms/step - loss: 0.1431 - accuracy: 0.9629 - val_loss: 0.1626 - val_accuracy: 0.9500
<tensorflow.python.keras.callbacks.History at 0x7f9edf0aef90>
Checking performance
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)
_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9660999774932861
Quant test accuracy: 0.9660000205039978
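For intuition about what a quantization-aware model does during training: QAT emulates low-precision inference in the forward pass by clipping values to a range, rounding them to a fixed grid, and mapping them back to floats (often called fake quantization). Below is a minimal pure-Python sketch of that idea. It is not TFMOT's actual implementation; the function name `fake_quantize`, the fixed range, and the 2-bit usage are illustrative assumptions.

```python
def fake_quantize(values, min_val=-1.0, max_val=1.0, num_bits=8):
    """Quantize-then-dequantize each value onto a uniform grid.

    Hypothetical helper for illustration only: it clips to
    [min_val, max_val], rounds to one of 2**num_bits levels,
    and returns the result as floats again.
    """
    levels = 2 ** num_bits - 1            # number of steps between min and max
    scale = (max_val - min_val) / levels  # width of one quantization step
    result = []
    for v in values:
        clipped = min(max(v, min_val), max_val)      # saturate out-of-range values
        q = round((clipped - min_val) / scale)       # integer grid index
        result.append(min_val + q * scale)           # back to a float on the grid
    return result


# With 2 bits on [0, 1] the grid is {0, 1/3, 2/3, 1}.
print(fake_quantize([-0.2, 0.4, 0.9, 1.7], min_val=0.0, max_val=1.0, num_bits=2))
# roughly [0.0, 0.333, 1.0, 1.0]
```

Because the model sees this rounding error while it trains, its weights can adapt to it, which is consistent with the quantized accuracy above staying close to the baseline.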
QAT using tf.GradientTape()
Here is the QAT training part on the base model. Note we can also perform custom training over the base model.
batch_size = 500
train_dataset = tf.data.Dataset.from_tensor_slices((train_images_subset,
                                                    train_labels_subset))
train_dataset = train_dataset.batch(batch_size=batch_size,
                                    drop_remainder=False)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
for epoch in range(1):
    for x, y in train_dataset:
        # Record the forward pass so gradients can be computed.
        with tf.GradientTape() as tape:
            preds = q_aware_model(x, training=True)
            loss = loss_fn(y, preds)
        # Backpropagate and update the quantization-aware model's weights.
        grads = tape.gradient(loss, q_aware_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, q_aware_model.trainable_variables))
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)
_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9660999774932861
Quant test accuracy: 0.9645000100135803
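If you want the custom loop to run in graph mode and to report training accuracy without calling .evaluate, one possible variant is to wrap the step in tf.function and update a Keras metric inside it. The sketch below is an assumption-laden illustration, not part of the official example: `make_train_step` is a hypothetical helper, and a small stand-in model with random data is used so the snippet runs on its own. In the setting above you would pass q_aware_model and train_dataset instead.

```python
import numpy as np
import tensorflow as tf


def make_train_step(model, loss_fn, optimizer, metric):
    """Return a graph-compiled function that trains on one batch."""

    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            loss = loss_fn(y, preds)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        metric.update_state(y, preds)  # accumulate accuracy across batches
        return loss

    return train_step


# Stand-in model and random data (assumption: replace with q_aware_model
# and the MNIST subset from the answer).
toy_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])
x_demo = np.random.rand(64, 28, 28).astype("float32")
y_demo = np.random.randint(0, 10, size=(64,))
toy_model(x_demo[:1])  # build the model's variables before tracing

demo_dataset = tf.data.Dataset.from_tensor_slices((x_demo, y_demo)).batch(16)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
train_step = make_train_step(toy_model, loss_fn, optimizer, accuracy)

losses = []
for x, y in demo_dataset:
    losses.append(float(train_step(x, y)))

print("mean loss:", sum(losses) / len(losses))
print("train accuracy:", float(accuracy.result()))
```

Building the step through a factory keeps the model, optimizer, and metric explicit, so the same function can be reused for the base model and the quantization-aware model.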