Tags: python, tensorflow, tensorflow2.0, tensorflow-lite, quantization

Tensorflow quantization process in detail - no one explains this in detail


I am currently looking at how the quantization process works in TensorFlow Lite. However, the exact process is not explained anywhere. (They just explain the code for quantizing a model with TFLite.)

When doing integer quantization, we know that the int8 quantization process for linear operations is done as follows.

According to https://www.tensorflow.org/lite/performance/quantization_spec, for Y = WX + b (Y: output, W: weight, X: input, b: bias), using the relationship r = S(q - z) (r: real value, S: scale factor, q: quantized value, z: zero-point), the following equation can be derived: q_y = M * (q_w * q_x - Z_x * q_w + q_b) + Z_y, where M = S_w * S_x / S_y.
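To make this concrete, here is a minimal Python sketch with made-up scales and zero-points (float M is used directly here; its fixed-point form is discussed next). Dequantizing q_y recovers Y = WX + b up to quantization error.

```python
S_w, S_x, S_y = 0.02, 0.1, 0.15     # assumed per-tensor scales (made up for the example)
Z_x, Z_y = 3, -5                    # assumed zero-points; Z_w = 0 for int8 weights per the spec

W, X, b = 0.46, 1.2, 0.3            # real (float) values
q_w = round(W / S_w)                # weight: scale S_w, zero-point 0
q_x = round(X / S_x) + Z_x          # input: scale S_x, zero-point Z_x
q_b = round(b / (S_w * S_x))        # bias: scale S_w * S_x, zero-point 0

M = S_w * S_x / S_y
acc = q_w * q_x - Z_x * q_w + q_b   # the accumulator from the equation above (426 here)
q_y = round(M * acc) + Z_y          # quantized output

print(S_y * (q_y - Z_y), W * X + b) # 0.9 vs 0.852 -> equal up to one output quantization step
```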

And, according to https://arxiv.org/abs/1712.05877, the floating-point number M can be approximated as M0 * 2^(-n), where M0 is an int32 fixed-point number.
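Here is a sketch of how such a decomposition can be computed. It mirrors the idea described in the paper (and, as far as I understand it, what TFLite's QuantizeMultiplier does), with M0 assumed to be stored as a Q0.31 fixed-point int32:

```python
import math

def quantize_multiplier(M: float):
    """Decompose M into (M0, n) so that M ~= M0 * 2**(-31) * 2**(-n)."""
    assert 0.0 < M < 1.0                   # typical for M = S_w * S_x / S_y
    mantissa, exponent = math.frexp(M)     # M = mantissa * 2**exponent, mantissa in [0.5, 1)
    M0 = round(mantissa * (1 << 31))       # Q0.31 fixed point, fits in an int32
    n = -exponent                          # extra right-shift amount
    if M0 == (1 << 31):                    # mantissa rounded up to 1.0: renormalize
        M0 //= 2
        n -= 1
    return M0, n

M0, n = quantize_multiplier(0.013333)
print(M0, n, M0 * 2.0 ** (-31 - n))        # reconstructed value ~= 0.013333
```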

So, let's talk about the number of bits in the quantization process (inference case).

  1. q_w * q_x is int32 (strictly speaking, the required accumulator width depends on the tensor sizes of W and X, but let's just assume int32)
  2. (-Z_x * q_w + q_b) is int32, and it is a known value (pre-computed; a short sketch of this term follows the list)
  3. M0 is int32 (fixed point number)
  4. Z_y is int32 (according to the converted TFLite model)
  5. q_y should be int8
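As referenced in item 2, here is a small NumPy sketch (shapes and values are made up) of the pre-computed term, written out for a matrix-vector product, where the scalar notation -Z_x * q_w + q_b becomes q_b[i] - Z_x * sum_k q_w[i, k] per output channel:

```python
import numpy as np

rng = np.random.default_rng(0)
q_w = rng.integers(-127, 128, size=(4, 8)).astype(np.int32)  # int8 weights widened to int32
q_b = rng.integers(-1000, 1000, size=4).astype(np.int32)     # int32 bias
Z_x = 3                                                      # input zero-point

precomputed = q_b - Z_x * q_w.sum(axis=1)                    # int32, known before inference

q_x = rng.integers(-128, 128, size=8).astype(np.int32)       # int8 input widened to int32
acc = q_w @ q_x + precomputed                                # full int32 accumulator per output
```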

My question is here. q_w * q_x - Z_x * q_w + q_b is an int32 after all. Since M0 is int32, M * (q_w * q_x - Z_x * q_w + q_b) = M0 * (q_w * q_x - Z_x * q_w + q_b) * 2^(-n), so an int32 gets multiplied by an int32 and becomes an int64 (followed by the rounding bit-shift by n). It is still a 64-bit value, though. How can we add the int32 Z_y to it? And how can we then say that q_y is 8-bit? What is the role of M?

Thank you

I am expecting to understand the quantization process in detail.


Solution

  • One thing to keep in mind is that the tensor quantization scales are evaluated to make sure that:

    • the quantized values fit in 8-bit,
    • the multiplication of these values by the scales covers the range of the original float tensor.

    With that in mind, the output scale makes sure that the final result fits in 8-bit: once you have applied all operations in the correct order, you can just cast the final result.

    Regarding the M quantity, you can see it as the reciprocal of the scale required to downscale the int32 output of the previous operation to 8-bit (downscaling is an operation very similar to quantization, applied to integer values).

    M being expressed as a fixed-point number (M0 * 2^(-n)), the downscale operation is composed of (a rough sketch of the full path is given at the end of this answer):

    • a multiplication by M0 producing a new fixed-point number with the same implicit exponent n,
    • a right-shift operation reducing the actual bitwidth by n and producing an integer (or equivalently a fixed-point number with a zero exponent).

    The scales have been evaluated precisely to make sure that, for "typical" inputs, i.e. inputs similar to those used for calibration, that final integer fits in 8-bit. If it doesn't, it is clipped to the int8/uint8 boundaries.

    Note: there is a chance of overflow when applying the multiplication. This can be mitigated by using an intermediate 64-bit accumulator, or by simply reducing the bitwidth of M0 (8-bit is enough most of the time).
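To tie this back to the question, here is a rough Python sketch of the downscale path described above. It is not the actual gemmlowp/TFLite kernel (which uses a saturating rounding doubling high multiply); it only illustrates the idea, reusing the hypothetical quantize_multiplier helper from the earlier sketch and assuming M0 is stored in Q0.31. The int32-by-int32 product lives in a 64-bit intermediate, the rounding right shift by 31 + n brings it back into int32 range, and only then is Z_y added and the result clamped, so q_y really ends up as an 8-bit value:

```python
def downscale(acc: int, M0: int, n: int, Z_y: int) -> int:
    """Downscale an int32 accumulator to an int8 output value."""
    prod = acc * M0                                  # int32 * int32 -> needs a 64-bit intermediate
    shift = 31 + n                                   # undo the Q0.31 encoding of M0, then shift by n
    rounded = (prod + (1 << (shift - 1))) >> shift   # rounding right shift, back into int32 range
    q_y = rounded + Z_y                              # add the output zero-point
    return max(-128, min(127, q_y))                  # clip to the int8 boundaries

# With the toy values from the earlier sketches (acc = 426, M ~ 0.013333, Z_y = -5):
M0, n = quantize_multiplier(0.013333)
print(downscale(426, M0, n, -5))                     # 1, an ordinary int8 value
```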