
Is it possible to compute the median while data is still being generated? Python online median calculator


I've seen a broader version of this question, where the asker wanted more than one summary statistic, but I haven't seen a solution presented. Here I'm only interested in the median, in Python.

Let's say I'm generating a million values in a loop. Because of memory constraints, I can't save all million values in a list and compute the median once I'm done. Is it possible to compute the median as I go? For the mean, I'd just keep a running sum and divide by a million at the end. For the median, the answer doesn't seem as obvious.

I'm stuck at the "thought experiment" stage, so I haven't been able to try anything I think might work. I don't know whether an algorithm for this has been implemented; if it has, I can't find it.


Solution

  • This isn't possible in general unless your "values" are restricted in some exploitable way, you can make more than one pass over the data, or you're willing to store data on disk too (or some combination of these). Suppose you know there will be 5 values, all distinct integers, and that the first 3 are 5, 6, 7. The median will be one of those, but at this point you can't know which, so you have to remember all of them. If 1 and 2 come next, the median is 5; if 4 and 8 come next, it's 6; if 8 and 9 come next, it's 7.

    This generalizes to any odd number 2*N+1 of distinct values: once you've seen the first N+1 of them, the median can still turn out to be any one of those N+1, because the N values still to come can be split in any proportion between "below everything seen so far" and "above everything seen so far". So unless there's something exploitable about the nature of the values, you have to remember all of them at that point.
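    To make that argument concrete, here's a small sketch (the helper name completion_forcing_median is made up for illustration, and it assumes the values seen so far are N+1 distinct integers). It builds the N remaining values so that whichever seen value you pick becomes the median of all 2*N+1, and the standard library's statistics.median confirms it:

```python
import statistics


def completion_forcing_median(seen, k):
    """Hypothetical helper: given N+1 distinct integers `seen`, return N more
    distinct integers such that sorted(seen)[k] is the median of all 2*N+1."""
    s = sorted(seen)
    n = len(s) - 1  # N values are still to come
    # sorted(seen)[k] already has k seen values below it and N - k above it;
    # it needs N on each side, so add N - k values below and k values above.
    below = [s[0] - 1 - j for j in range(n - k)]
    above = [s[-1] + 1 + j for j in range(k)]
    return below + above


seen = [5, 6, 7]  # the first 3 of 5 values, as in the example above
for k in range(len(seen)):
    rest = completion_forcing_median(seen, k)
    print(rest, statistics.median(seen + rest))
# [4, 3] 5
# [4, 8] 6
# [8, 9] 7
```

    Any fixed-size summary that forgets one of the seen values is therefore wrong for at least one possible continuation.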

    An example of something exploitable: you know there are at most 100 distinct values. Then you can use a dict to count how many of each appear, and easily calculate the median at the end from that compressed representation of the distribution.
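    A minimal sketch of that counting approach, using collections.Counter (a dict subclass) in place of a plain dict. The class name CountingMedian is hypothetical. Memory grows with the number of distinct values rather than the number of values, so it only helps when that number is small. For an even count it averages the two middle values, the same convention statistics.median uses:

```python
from collections import Counter


class CountingMedian:
    """Hypothetical exact streaming median for data with few distinct values."""

    def __init__(self):
        self.counts = Counter()  # value -> number of times it was added
        self.n = 0               # total number of values added

    def add(self, x):
        self.counts[x] += 1
        self.n += 1

    def get(self):
        if self.n == 0:
            raise ValueError("no data")
        # 0-based positions of the middle element(s) in sorted order;
        # they coincide when n is odd.
        lo_pos = (self.n - 1) // 2
        hi_pos = self.n // 2
        seen = 0
        lo_val = None
        for value in sorted(self.counts):
            seen += self.counts[value]
            if lo_val is None and seen > lo_pos:
                lo_val = value
            if seen > hi_pos:
                return (lo_val + value) / 2 if lo_pos != hi_pos else value


m = CountingMedian()
for i in range(1_000_000):  # a million values, but only 100 distinct ones
    m.add(i % 100)
print(len(m.counts), m.get())  # 100 49.5
```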

    Approximating

    For the reasons above, there's no general "shortcut" to be had here. But I'll attach Python code for a reasonable one-pass approximation method, the remedian, as detailed in "The Remedian: A Robust Averaging Method for Large Data Sets". That paper also points to other approximation methods.

    The key idea: pick an odd integer B greater than 1. Successive elements are stored in a buffer until B of them have been recorded. At that point, their median advances to the next level, and the buffer is cleared. That median is the only trace of those B elements that is retained.

    The same pattern continues at deeper levels too: after B of those median-of-B medians have been recorded, their median advances to the next level, and the second-level buffer is cleared. The median that advances is then the only trace of the B**2 elements that went into it.
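    When the number of elements is an exact power of B, the level-by-level reduction can be written out in batch form. This is only an illustration of what the buffers end up computing, not a streaming implementation, and the function names are made up. With B = 3 and 9 = 3**2 inputs, the three group medians advance to a second level and their median is the final estimate, which need not equal the true median:

```python
def median_of_odd(xs):
    """Middle element of an odd-length sequence."""
    return sorted(xs)[len(xs) // 2]


def remedian_batch(data, B):
    """Hypothetical illustration: reduce len(data) == B**k values by taking
    medians of consecutive groups of B, level after level, until one remains."""
    level = list(data)
    if not level:
        raise ValueError("no data")
    while len(level) > 1:
        if len(level) % B:
            raise ValueError("len(data) must be a power of B")
        level = [median_of_odd(level[i:i + B]) for i in range(0, len(level), B)]
    return level[0]


data = [1, 2, 9, 3, 4, 8, 5, 6, 7]
# groups [1, 2, 9], [3, 4, 8], [5, 6, 7] -> medians 2, 4, 6 -> median 4
print(remedian_batch(data, 3))  # 4
print(median_of_odd(data))      # 5, the true median
```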

    And so on. At worst it requires storing about B * log(N, B) values (log base B of N), where N is the total number of elements. In Python it's easy to create buffers only as they're needed, so N doesn't need to be known in advance.
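    To put numbers on that bound, here's a small sketch (the function name remedian_storage is hypothetical). Level j receives one value per B**j inputs, so it exists once at least B**j elements have arrived, and a level never holds more than B - 1 values between additions because it's cleared as soon as it reaches B. For the million-element example in this answer with B = 11, that works out to 6 levels and at most 60 buffered values, in line with B * log(N, B):

```python
import math


def remedian_storage(N, B):
    """Hypothetical helper: (number of levels, worst-case buffered values)
    for N inputs with buffer size B."""
    levels = 0
    size = 1  # inputs summarized by one value at the current level, B**levels
    while size <= N:  # level `levels` exists once N >= B**levels
        levels += 1
        size *= B
    return levels, levels * (B - 1)


N, B = 1000001, 11
print(remedian_storage(N, B))         # (6, 60)
print(round(B * math.log(N, B), 1))   # 63.4
```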

    If B >= N, the method is exact, but then you've also stored every element. If B < N, it's an approximation to the median. See the paper for details; they're quite involved. Here's a case that makes it look very good ;-)

    >>> import random
    >>> xs = [random.random() for i in range(1000001)]
    >>> sorted(xs)[500000] # true median
    0.5006315438367565
    >>> w = MedianEst(11)
    >>> for x in xs:
    ...     w.add(x)
    >>> w.get()
    0.5008443883489089
    

    Perhaps surprisingly, it does worse if the inputs are added in sorted order:

    >>> w.clear()
    >>> for x in sorted(xs):
    ...     w.add(x)
    >>> w.get()
    0.5021045181828147
    

    User beware! Here's the code:

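    class MedianEst:
        """Remedian: a one-pass approximation of the median."""

        def __init__(self, B):
            # B is the buffer size per level; it must be odd so every
            # full buffer has a unique middle element.
            if not (B > 1 and B & 1):
                raise ValueError("B must be an odd integer greater than 1")
            self.B = B
            self.half = B >> 1  # index of the middle element of a full buffer
            self.clear()

        def add(self, x):
            # Push x into level 0. Whenever a level fills up, replace its
            # contents by their median and carry that median to the next level.
            for xs in self.a:
                xs.append(x)
                if len(xs) == self.B:
                    x = sorted(xs)[self.half]
                    xs.clear()
                else:
                    break
            else:
                # Every existing level overflowed (or there are none yet):
                # start a new, deeper level holding the carried value.
                self.a.append([x])

        def get(self):
            # Each value at level k stands in for B**k original elements,
            # so return the weighted median of everything still buffered.
            total = 0
            weight = 1
            accum = []
            for xs in self.a:
                total += len(xs) * weight
                accum.extend((x, weight) for x in xs)
                weight *= self.B
            # `total` elements in all
            limit = total // 2 + 1
            running = 0
            for x, weight in sorted(accum):
                running += weight
                if running >= limit:
                    return x
            return None  # nothing has been added yet

        def clear(self):
            # Discard all buffers; self.a[k] is the buffer for level k.
            self.a = []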