I am trying to find a TensorFlow equivalent of np.quantile(). I have found tfp.stats.quantiles() (tfp stands for TensorFlow Probability). However, its interface is a bit different from that of np.quantile().
Consider the following example:
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
inputs = tf.random.normal((1, 4096, 4))
print("NumPy")
print(np.quantile(inputs.numpy(), q=0.9, axis=1, keepdims=False))
I am not sure from the TFP docs how I could write the above using tfp.stats.quantiles(). I tried checking out the source code of both methods, but it didn't help.
Let me try to be more helpful here than I was on GitHub.
There is a difference in behavior between np.quantile and tfp.stats.quantiles. The key difference is that, according to its docs, numpy.quantile will

"Compute the q-th quantile of the data along the specified axis."

where q is the

"Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive."

whereas the docs for tfp.stats.quantiles say

"Given a vector x of samples, this function estimates the cut points by returning num_quantiles + 1 cut points"
So you need to tell tfp.stats.quantiles how many quantiles you want and then select the q-th quantile from the cut points it returns. If it isn't clear how to do this from the API alone, the source for tfp.stats.quantiles (for v0.19.0) shows how to get a return structure similar to NumPy's.
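To make the "cut points" idea concrete before bringing in TFP, here is a minimal NumPy-only sketch. It only mimics the behavior described in the docs quoted above (num_quantiles + 1 evenly spaced cut points along an axis) and is not TFP's actual implementation. The names levels, cut_points, and index are made up for illustration, and the seeded random input is an assumption chosen to match the shape used in the question.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4096, 4))

num_quantiles = 100
q = 0.9

# num_quantiles + 1 evenly spaced levels, from 0.0 (minimum) to 1.0 (maximum).
levels = np.linspace(0.0, 1.0, num_quantiles + 1)

# One cut point per level; the levels become the leading dimension.
cut_points = np.quantile(x, levels, axis=1)  # shape (101, 1, 4)

# The q-th quantile sits at index q * num_quantiles. round() is used instead
# of int() because products such as 0.29 * 100 fall just below the integer.
index = round(q * num_quantiles)
selected = cut_points[index]

# Asking NumPy for the single quantile directly gives the same values.
direct = np.quantile(x, q, axis=1)
assert selected.shape == direct.shape == (1, 4)
assert np.allclose(selected, direct)
```

The same indexing applies to the cut points returned by tfp.stats.quantiles, as long as q * num_quantiles is a whole number. For a q that does not land on a cut point, num_quantiles would have to be increased.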
For completeness, setting up a virtual environment with
$ cat requirements.txt
numpy==1.24.2
tensorflow==2.11.0
tensorflow-probability==0.19.0
allows us to run
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
inputs = tf.random.normal((1, 4096, 4), dtype=tf.float64)
q = 0.9
numpy_quantiles = np.quantile(inputs.numpy(), q=q, axis=1, keepdims=False)
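# quantiles() returns num_quantiles + 1 cut points along the leading dimension,
# so index 90 of 100 is the 0.9 quantile. round() avoids float truncation.
tfp_quantiles = tfp.stats.quantiles(
    inputs, num_quantiles=100, axis=1, interpolation="linear"
)[round(q * 100)]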
assert np.allclose(numpy_quantiles, tfp_quantiles.numpy())
print(f"{numpy_quantiles=}")
# numpy_quantiles=array([[1.31727661, 1.2699167 , 1.28735237, 1.27137588]])
print(f"{tfp_quantiles=}")
# tfp_quantiles=<tf.Tensor: shape=(1, 4), dtype=float64, numpy=array([[1.31727661, 1.2699167 , 1.28735237, 1.27137588]])>