python numpy tensorflow median

TensorFlow median value


How can I calculate the median value of a list in TensorFlow? Something like:

node = tf.median(X)

where X is the placeholder.
In NumPy, I can directly use np.median to get the median value. How can I use that NumPy operation in TensorFlow?


Solution

  • Edit: This answer is outdated; use Lucas Venezian Povoa's solution instead. It is simpler and faster.

    You can calculate the median inside TensorFlow with tf.nn.top_k:

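    import tensorflow as tf

    def get_median(v):
        v = tf.reshape(v, [-1])  # flatten to 1-D
        # tf.size works even when the static shape is unknown (e.g. [None])
        mid = tf.size(v) // 2 + 1
        # top_k sorts descending; the last of the `mid` largest is the median
        return tf.nn.top_k(v, mid).values[-1]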
    

    If X is already a vector, you can skip the reshaping.
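For an even number of elements, get_median returns the lower of the two middle values. The same selection can be sketched with tf.sort instead of tf.nn.top_k, assuming TensorFlow 2.x with eager execution; the helper name get_lower_median is hypothetical. It sorts the flattened tensor in ascending order and picks index (n - 1) // 2.

```python
import tensorflow as tf


def get_lower_median(x):
    """Lower median of all elements of `x` (hypothetical helper name).

    Index (n - 1) // 2 of the ascending sort is the middle element for odd n
    and the lower of the two middle elements for even n.
    """
    v = tf.sort(tf.reshape(x, [-1]))  # flatten, then sort ascending
    n = tf.size(v)                    # dynamic size, works for unknown shapes
    return v[(n - 1) // 2]


print(get_lower_median(tf.constant([3.0, 1.0, 2.0])).numpy())          # 2.0
print(get_lower_median(tf.constant([[4.0, 1.0], [3.0, 2.0]])).numpy())  # 2.0
```

Sorting costs O(n log n) for the whole tensor, while top_k only needs the largest n // 2 + 1 values; for small inputs the difference is unlikely to matter.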

    If you want the median of an even-sized vector to be the mean of the two middle elements, use this instead:

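    def get_real_median(v):
        v = tf.reshape(v, [-1])
        l = tf.size(v)
        mid = l // 2 + 1
        # The `mid` largest values, in descending order.
        val = tf.nn.top_k(v, mid).values
        # Odd size: indices l // 2 and (l - 1) // 2 coincide (the middle element).
        # Even size: they select the lower and upper middle elements.
        # This avoids a Python `if` on the length, which fails for unknown shapes.
        # Integer inputs give a float64 result because `/` is true division.
        return (val[l // 2] + val[(l - 1) // 2]) / 2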
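The question also asks about reusing np.median itself. One option, assuming TensorFlow 2.x, is tf.numpy_function, which runs a Python function on the NumPy value of a tensor. This is a minimal sketch; numpy_median is a hypothetical helper name, and the cast to float64 is a choice made so the declared output dtype always matches what NumPy returns.

```python
import numpy as np
import tensorflow as tf


def numpy_median(x):
    """Median computed by np.median, wrapped as a TensorFlow op (hypothetical helper).

    Gradients do not flow through tf.numpy_function, and its Python body is not
    serialized with the graph, so this is unsuitable for exported models.
    """
    x = tf.cast(x, tf.float64)
    # Tout must match the returned NumPy dtype exactly, hence the explicit dtype.
    result = tf.numpy_function(
        lambda a: np.asarray(np.median(a), dtype=np.float64), [x], tf.float64
    )
    return tf.reshape(result, [])  # numpy_function drops the static shape


print(numpy_median(tf.constant([1.0, 4.0, 2.0, 3.0])).numpy())  # 2.5
```

Because np.median averages the two middle values for even sizes, this matches get_real_median rather than get_median.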