Tags: python, tensorflow, rounding, quantization

What is the round through function in QKeras/Python?


I was going through the QKeras implementation of the 'quantized_bits' class. Inside the call function I came across a '_round_through' function.

Here's how the function was being called:

if unsigned_bits > 0:
    p = x * m / m_i
    xq = m_i * tf.keras.backend.clip(
        _round_through(p, self.use_stochastic_rounding, precision=1.0),
        self.keep_negative * (-m + self.symmetric), m - 1) / m
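To see what this snippet computes in the forward pass, here is a plain-Python sketch of the same arithmetic for a single scalar. It assumes `m`, `m_i`, `keep_negative` and `symmetric` are already-computed numbers (in QKeras they come from the quantizer's configuration, which is not shown here). It uses Python's `round` in place of `_round_through`, because with stochastic rounding off both produce the same forward value (Python's `round` and `tf.round` both round halves to even). The helper name `quantize_forward` is hypothetical.

```python
def quantize_forward(x, m, m_i, keep_negative=True, symmetric=False):
    """Forward value of the quoted quantized_bits snippet for one float.

    Hypothetical helper: m, m_i, keep_negative and symmetric are assumed to be
    plain numbers/bools supplied by the caller, not read from a QKeras object.
    """
    p = x * m / m_i
    # Stand-in for _round_through(p, ...) with stochastic rounding disabled:
    # round() rounds halves to even, like tf.round.
    rounded = round(p)
    low = keep_negative * (-m + symmetric)  # lower clip bound (0 if negatives are dropped)
    high = m - 1                            # upper clip bound
    clipped = min(max(rounded, low), high)  # same role as tf.keras.backend.clip
    return m_i * clipped / m
```

With `m = 4` and `m_i = 1`, an input of 0.9 is scaled to 3.6, rounded to 4, clipped to 3 and returned as 0.75, so the clip is what keeps the result inside the representable range.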

I tried running the code in Python, but '_round_through' is not a built-in function. Searching for it only returned results about Python's built-in 'round()' function.

So my question is: what does this function do? And what module is it a part of?

Link to Code: https://github.com/google/qkeras/blob/master/qkeras/quantizers.py#L489

Any help would be appreciated!


Solution

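  • If you click on the function name in GitHub's code view, it shows where the function is defined. In this case, '_round_through' is defined in the same file, qkeras/quantizers.py, at line 271. It is a private helper of QKeras (the leading underscore is Python's convention for internal names), not a Python built-in, which is why you could not call it directly.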

    def _round_through(x, use_stochastic_rounding=False, precision=0.5):
      """Rounds x but using straight through estimator.
      We use the trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182).
      Straight through estimator is a biased estimator for the rounding
      operation defined by Hinton's Coursera Lecture 9c where dL/dx is made
      equal to dL/dy for y = f(x) during gradient computation, where f(x) is
      a non-differentiable function. In that case, we assume df/dx = 1 in:
      dL   dL df   dL
      -- = -- -- = --
      dx   df dx   dy
      (https://www.youtube.com/watch?v=LN0xtUuJsEI&list=PLoRl3Ht4JOcdU872GhiYWf6jwrk_SNhz9&index=41)
      Arguments:
        x: tensor to perform round operation with straight through gradient.
        use_stochastic_rounding: if true, we perform stochastic rounding.
        precision: by default we will use 0.5 as precision, but that can be
          overridden by the user.
      Returns:
        Rounded tensor.
      """
      if use_stochastic_rounding:
        output = tf_utils.smart_cond(
            K.learning_phase(),
            lambda: x + tf.stop_gradient(-x + stochastic_round(x, precision)),
            lambda: x + tf.stop_gradient(-x + tf.round(x)))
      else:
        output = x + tf.stop_gradient(-x + tf.round(x))
      return output
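The trick `x + tf.stop_gradient(-x + tf.round(x))` works because the two `x` terms cancel in the forward pass, so the output equals `round(x)`, while during backpropagation `stop_gradient` treats the bracketed term as a constant, so only the leading `x` contributes and the gradient is 1. The sketch below demonstrates this with a tiny hand-written forward-mode autodiff type instead of TensorFlow. `Dual`, `stop_gradient`, `dual_round`, `dual_stochastic_round` and `round_through` are hypothetical stand-ins, not QKeras or TensorFlow APIs, and the `training` flag stands in for `K.learning_phase()`. The stochastic branch uses a generic stochastic-rounding rule (round up with probability equal to the fractional part); QKeras's own `stochastic_round(x, precision)` is not shown here and may differ, in particular in how it uses `precision`.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class Dual:
    """A value together with its derivative d(value)/dx (forward-mode autodiff)."""
    value: float
    grad: float

    def __add__(self, other):
        return Dual(self.value + other.value, self.grad + other.grad)

    def __neg__(self):
        return Dual(-self.value, -self.grad)


def stop_gradient(d):
    # Keep the value but drop the derivative, mimicking tf.stop_gradient.
    return Dual(d.value, 0.0)


def dual_round(d):
    # round() is piecewise constant, so its true derivative is 0 almost everywhere.
    return Dual(float(round(d.value)), 0.0)


def dual_stochastic_round(d, rng):
    # Generic stochastic rounding: round up with probability equal to the fractional part.
    floor = math.floor(d.value)
    up = rng.random() < d.value - floor
    return Dual(float(floor + 1 if up else floor), 0.0)


def round_through(x, use_stochastic_rounding=False, training=False, rng=None):
    """Straight-through rounding; `training` stands in for K.learning_phase()."""
    if use_stochastic_rounding and training:
        rounded = dual_stochastic_round(x, rng or random.Random())
    else:
        rounded = dual_round(x)
    # Forward: x + (round(x) - x) == round(x). Backward: only the leading x has a gradient.
    return x + stop_gradient(-x + rounded)
```

For example, `round_through(Dual(0.7, 1.0))` returns `Dual(value=1.0, grad=1.0)`: the value is rounded, but the derivative passes through unchanged, whereas `dual_round(Dual(0.7, 1.0)).grad` is `0.0`. That zero gradient of plain rounding is the motivation for the straight-through estimator. As in the quoted function, stochastic rounding is only used when both the flag and the training phase are on.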