Tags: python, tensorflow, quantization

Use a tensor's value to modify another tensor at run-time


In TensorFlow I'd like to apply one of the "fake" quantization functions to one of my tensors. Concretely, I'm interested in using this function:

fake_quant_with_min_max_args(
    inputs,
    min=-6,
    max=6,
    num_bits=8,
    narrow_range=False,
    name=None
)
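
For reference, a minimal usage sketch (the tensor and values below are illustrative, not from my actual setup):

import tensorflow as tf

# Quantize a tensor to 8 bits within a fixed [-6, 6] range.
x = tf.constant([-10.0, -2.5, 0.0, 3.7, 12.0])
x_quant = tf.fake_quant_with_min_max_args(x, min=-6, max=6, num_bits=8)

with tf.Session() as sess:
    # Values are clipped to [-6, 6] and snapped onto the 256-level grid.
    print(sess.run(x_quant))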

It works just fine out of the box. Ideally, though, I would like to adjust the min and max arguments depending on the input tensor. Concretely, I want to use the 99.7% rule to define that range: treating the input tensor as a 1-dimensional vector, I want to use the range [mean - 3*std, mean + 3*std], which contains roughly 99.7% of its elements.

For this purpose I do the following:

def smartFakeQuantization(tensor):
    # Convert the input tensor to a 1-d tensor
    t_1d_data = tf.reshape(tensor, [tf.size(tensor), 1])
    # Get the moments of that tensor. mean and var have shape (1,)
    mean, var = tf.nn.moments(t_1d_data, axes=[0])
    # Get a tensor containing the std
    std = tf.sqrt(var)

    < some code to get the values of those tensors at run-time>

    clip_range = np.round([mean_val - 3*std_val, mean_val + 3*std_val], decimals=3)
    return tf.fake_quant_with_min_max_args(tensor, min=clip_range[0], max=clip_range[1])

I know that I could evaluate any tensor in the graph with myTensor.eval() or mySession.run(myTensor), but if I add lines like that inside my function above, it crashes when executing the graph with an error of the form:

tensor <...> has been marked as not fetchable.

The steps I'm following are probably not the right ones given the "graph" nature of TensorFlow. Any ideas how this could be done? Summarising, I want to use the value of a tensor at run-time to modify another tensor. I'd say this problem is more complex than what can be solved with tf.cond().


Solution

  • I don't think there is an easy way of doing what you want. The min and max arguments to fake_quant_with_min_max_args are converted to operation attributes and used in the underlying kernel's construction, so they cannot be changed at runtime. There are some ops that adjust their intervals depending on the data they see (see LastValueQuantize and MovingAvgQuantize, which seemingly are not part of the public API), but they don't do quite what you want.

    You can write your own custom op or, if you believe this is something generally valuable, file a feature request on GitHub. If the statistics can be computed before the quantization op is built, a two-pass workaround is sketched below.
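
    The following is a minimal sketch of such a workaround, assuming TF 1.x graph mode and that the input tensor is actually fetchable (i.e., it is not inside a control-flow context that makes it unfetchable); the function name and the two-pass structure are illustrative, not an established API:

    import numpy as np
    import tensorflow as tf

    def smart_fake_quantization(sess, tensor, feed_dict=None):
        # Because min/max are op attributes, they must be plain Python numbers
        # when the op is constructed, so the statistics are evaluated in a
        # separate session run before the fake-quant node is added to the graph.
        flat = tf.reshape(tensor, [-1])
        mean, var = tf.nn.moments(flat, axes=[0])
        std = tf.sqrt(var)

        # First pass: fetch mean and std as concrete values.
        mean_val, std_val = sess.run([mean, std], feed_dict=feed_dict)
        clip_min, clip_max = np.round(
            [mean_val - 3 * std_val, mean_val + 3 * std_val], decimals=3)

        # Second pass: construct the op with fixed attributes.
        return tf.fake_quant_with_min_max_args(
            tensor, min=float(clip_min), max=float(clip_max))

    Note that this freezes the clip range at graph-construction time; it will not track statistics that change from batch to batch, which is the kind of behaviour ops like MovingAvgQuantize are designed for.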