Tags: tensorflow, keras, quantization

TensorFlow 2 Quantization Aware Training (QAT) with tf.GradientTape


Can anyone point to references where one can learn how to perform Quantization Aware Training (QAT) with tf.GradientTape on TensorFlow 2?

I have only seen this done with the tf.keras API. I do not use tf.keras; I always build customized training loops with tf.GradientTape, which provides more control over the training process. I now need to quantize a model, but I only see references on how to do it using the tf.keras API.
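
For reference, this is the kind of custom training step I mean (a minimal, self-contained sketch with a toy model and dummy data, not my actual model):

    import tensorflow as tf

    # Toy model and dummy batch, just to illustrate the workflow.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()

    x = tf.random.normal((32, 784))
    y = tf.random.uniform((32,), maxval=10, dtype=tf.int32)

    # One training step: forward pass under the tape, then apply gradients.
    with tf.GradientTape() as tape:
        preds = model(x, training=True)
        loss = loss_fn(y, preds)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))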


Solution

  • The official examples here show QAT training with model.fit. Below is a demonstration of Quantization Aware Training using tf.GradientTape(), but for a complete reference, let's do both.


    Base model training. This is taken directly from the official doc; for more details, please check there.

    import os
    import tensorflow as tf
    from tensorflow import keras
    import tensorflow_model_optimization as tfmot
    
    # Load MNIST dataset
    mnist = keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    
    # Normalize the input image so that each pixel value is between 0 to 1.
    train_images = train_images / 255.0
    test_images = test_images / 255.0
    
    # Define the model architecture.
    model = keras.Sequential([
      keras.layers.InputLayer(input_shape=(28, 28)),
      keras.layers.Reshape(target_shape=(28, 28, 1)),
      keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
      keras.layers.MaxPooling2D(pool_size=(2, 2)),
      keras.layers.Flatten(),
      keras.layers.Dense(10)
    ])
    
    # Train the digit classification model
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    model.summary()
    model.fit(
      train_images,
      train_labels,
      epochs=1,
      validation_split=0.1,
    )
    
    10ms/step - loss: 0.5411 - accuracy: 0.8507 - val_loss: 0.1142 - val_accuracy: 0.9705
    <tensorflow.python.keras.callbacks.History at 0x7f9ee970ab90>
    

    QAT with .fit.

    Now, performing QAT over the base model.

    # ------------- Quantization Aware Training -------------
    import tensorflow_model_optimization as tfmot
    
    quantize_model = tfmot.quantization.keras.quantize_model
    # q_aware stands for quantization aware.
    q_aware_model = quantize_model(model)
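    # The call above wraps the supported layers with fake-quant ops, so the
    # forward pass now simulates (by default 8-bit) quantized inference.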
    
    # `quantize_model` requires a recompile.
    q_aware_model.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])
    
    q_aware_model.summary()
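    # Per the official guide, fine-tuning on a small subset for one epoch
    # is enough to demonstrate QAT after baseline training.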
    train_images_subset = train_images[0:1000] 
    train_labels_subset = train_labels[0:1000]
    q_aware_model.fit(train_images_subset, train_labels_subset,
                      batch_size=500, epochs=1, validation_split=0.1)
    
    
    356ms/step - loss: 0.1431 - accuracy: 0.9629 - val_loss: 0.1626 - val_accuracy: 0.9500
    <tensorflow.python.keras.callbacks.History at 0x7f9edf0aef90>
    

    Checking performance

    _, baseline_model_accuracy = model.evaluate(
        test_images, test_labels, verbose=0)
    
    _, q_aware_model_accuracy = q_aware_model.evaluate(
        test_images, test_labels, verbose=0)
    
    print('Baseline test accuracy:', baseline_model_accuracy)
    print('Quant test accuracy:', q_aware_model_accuracy)
    
    Baseline test accuracy: 0.9660999774932861
    Quant test accuracy: 0.9660000205039978
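
    Optionally, you can inspect what quantize_model did to the layers. A quick sketch (the exact wrapper class names may vary across tfmot versions):

    for layer in q_aware_model.layers:
        print(layer.name, type(layer).__name__)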
    

    QAT with tf.GradientTape().

    Here is the QAT training part, this time with a custom tf.GradientTape() loop. Note that the same kind of loop can also be used for custom training of the base model.

    batch_size = 500
    
    train_dataset = tf.data.Dataset.from_tensor_slices((train_images_subset,
                                                         train_labels_subset))
    train_dataset = train_dataset.batch(batch_size=batch_size, 
                                        drop_remainder=False)
    
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()
    
    for epoch in range(1):
        for x, y in train_dataset:
            with tf.GradientTape() as tape:
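                # training=True keeps the fake-quant ops in training mode,
                # so the quantization ranges continue to update during QAT.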
                preds = q_aware_model(x, training=True)
                loss = loss_fn(y, preds)
            grads = tape.gradient(loss, q_aware_model.trainable_variables)
            optimizer.apply_gradients(zip(grads, q_aware_model.trainable_variables))
            
    _, baseline_model_accuracy = model.evaluate(
        test_images, test_labels, verbose=0)
    
    _, q_aware_model_accuracy = q_aware_model.evaluate(
        test_images, test_labels, verbose=0)
    
    print('Baseline test accuracy:', baseline_model_accuracy)
    print('Quant test accuracy:', q_aware_model_accuracy)
    
    Baseline test accuracy: 0.9660999774932861
    Quant test accuracy: 0.9645000100135803
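
    Finally, to get an actually quantized model out of the QAT-trained one, the usual next step (as in the official guide) is TFLite conversion. A minimal sketch:

    converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    quantized_tflite_model = converter.convert()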