Tags: algorithm, optimization, median, median-of-medians

Find a median of N^2 numbers having memory for N of them


I was trying to learn about distributed computing and came across the problem of finding the median of a large set of numbers:

Assume that we have a large set of numbers (let's say the number of elements is N*K) that cannot fit into memory (which holds N elements). How do we find the median of this data? Assume that the operations performed in memory are independent, i.e. we can consider that there are K machines, each of which can process at most N elements.

I thought that median of medians could be used for this purpose. We can load N numbers at a time into memory. We find the median of that set in O(N) time (for example with quickselect; O(logN) would only be possible if the chunk were already sorted) and save it.

Then we save all these K medians and find the median of medians. That is another O(K), so the complexity so far is O(K*N + K).

But this median of medians is just an approximate median. I think it would be ideal to use it as a pivot to get best-case performance, but for that we would need to fit all N*K numbers in memory.

How can we find the actual median of the set now that we have a good approximate pivot?
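
As a rough sketch of what using that pivot could look like without holding all N*K numbers: take the median of each chunk with std::nth_element, then make one streaming pass that only counts how many values are below, equal to, and above the pivot. The names `chunkMedian`, `countAroundPivot` and `sideOfRank` are made up for this illustration, and the nested vector is only a stand-in for chunks that would be loaded one at a time.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Lower median of one chunk that fits in memory. Takes a copy because
// std::nth_element reorders its input; average cost is linear in the chunk size.
double chunkMedian(std::vector<double> chunk) {
    assert(!chunk.empty());
    auto mid = chunk.begin() + static_cast<std::ptrdiff_t>((chunk.size() - 1) / 2);
    std::nth_element(chunk.begin(), mid, chunk.end());
    return *mid;
}

struct PivotCounts {
    std::size_t less = 0;
    std::size_t equal = 0;
    std::size_t greater = 0;
};

// One streaming pass over all chunks: only three counters are kept in memory.
PivotCounts countAroundPivot(const std::vector<std::vector<double>>& chunks,
                             double pivot) {
    PivotCounts c;
    for (const auto& chunk : chunks) {
        for (double v : chunk) {
            if (v < pivot) ++c.less;
            else if (v > pivot) ++c.greater;
            else ++c.equal;
        }
    }
    return c;
}

// Where the element of 0-based rank `rank` lies relative to the pivot:
// -1 below it, 0 equal to it, +1 above it.
int sideOfRank(const PivotCounts& c, std::size_t rank) {
    if (rank < c.less) return -1;
    if (rank < c.less + c.equal) return 0;
    return 1;
}
```

If `sideOfRank` returns 0 the pivot is the exact median; otherwise only the values on the indicated side are still candidates, and the counts say which rank to look for among them in the next pass.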


Solution

  • Why don't you build a histogram, i.e. count the number of values that fall into each of several categories? The categories should be consecutive, non-overlapping intervals of the variable.

    With this histogram you can make a first estimate of the median (i.e., the median lies in some interval [a,b]) and know how many values fall into this interval (H). If H<=N, read the numbers again, ignoring those outside this interval and moving the numbers within the interval to RAM. Find the median: it is the value whose rank among the retained numbers equals the overall median rank minus the number of values smaller than a.

    If H>N, partition the interval again and repeat the procedure. It shouldn't take more than 2 or 3 iterations.

    Note that for each partition you only need to store a, b, the bin width Delta and the array with the number of values that fall into each subinterval.
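
    The last step (the H<=N case) could look roughly like the sketch below. It assumes the histogram pass has already given us the interval, treated here as half-open [a,b), and the 0-based rank of the median in the whole set. `selectInInterval`, `capacity` and the in-memory `stream` vector (a stand-in for re-reading the data from disk or network) are hypothetical names chosen for illustration.

    ```cpp
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Second pass over the data: keep only the values in [a, b), count those
    // below a, then select the wanted rank among the kept values.
    // `rank` is the 0-based rank of the median in the whole data set and
    // `capacity` is the number of values that fit in memory (N).
    double selectInInterval(const std::vector<double>& stream, double a, double b,
                            std::size_t rank, std::size_t capacity) {
        assert(a < b);
        std::size_t below = 0;       // values smaller than a (counted, not stored)
        std::vector<double> kept;    // values inside [a, b)
        for (double v : stream) {
            if (v < a) { ++below; continue; }
            if (v >= b) continue;
            if (kept.size() == capacity)
                throw std::length_error("interval holds more than N values");
            kept.push_back(v);
        }
        if (rank < below || rank - below >= kept.size())
            throw std::out_of_range("median is not inside [a, b)");
        auto target = kept.begin() + static_cast<std::ptrdiff_t>(rank - below);
        std::nth_element(kept.begin(), target, kept.end());
        return *target;
    }
    ```

    Using std::nth_element instead of a full sort keeps this step linear on average in the number of retained values.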

    EDIT. It turned out to be a bit more complicated than I expected. In each iteration, after estimating the interval the median falls into, we should also keep track of how much of the histogram we leave to the right and to the left of this interval. I changed the stop condition too. Anyway, I did a C++ implementation.

    #include <iostream>
    #include <algorithm>
    #include <float.h>   //DBL_MIN
    #include <time.h>
    #include <stdlib.h>
    
    //This is N^2... or just the number of values in your array,
    //note that we never modify it except at the end (just for sorting
    //and testing purposes).
    #define N2 1000000
    //Number of elements in the histogram. Must be >2
    #define HISTN 1000
    
    double findmedian (double *values, double min, double max);
    int getindex (int *hist);
    void put (int *hist, double min, double max, double val, double delta);
    
    
    int main ()
    {
        //Set max and min to the max/min values your array variables can hold,
        //calculate it, or maybe we know that they are bounded
        double max=1000.0;
        double min=0.0;
        double delta;
        //static storage: N2 doubles (~8 MB) can overflow the stack as a local array
        static double values[N2];
        int hist[HISTN];
        int ind;
        double median;
        int iter=0;
        //Initialize with random values   
        srand ((unsigned) (time(0)));
        for (int i=0; i<N2; ++i)
            values[i]=((double)rand()/(double)RAND_MAX);
    
        double imin=min;
        double imax=max;
    
        clock_t begin=clock(); 
        while (1) {
            iter++;
            for (int i=0; i<HISTN; ++i)
                hist[i]=0;
    
            delta=(imax-imin)/HISTN;
            for (int j=0; j<N2; ++j)
                put (hist, imin, imax, values[j], delta);
    
            ind=getindex (hist);
            imax=imin;
            imin=imin+delta*ind;
            imax=imax+delta*(ind+1);
    
            if (hist[ind]==1 || imax-imin<=DBL_MIN) {
                median=findmedian (values, imin, imax);
                break;
            }   
        }
    
        clock_t end=clock();
        std::cout << "Median with our algorithm: " << median << " - " << iter << " iterations of the algorithm" << std::endl; 
        double time=(double)(end-begin)/CLOCKS_PER_SEC;
        std::cout << "Time: " << time << std::endl;  
    
        //Let's compare our result with the median calculated after sorting the
        //array
        //Should be values[(int)N2/2] if N2 is odd
        begin=clock();
        std::sort (values, values+N2);
        std::cout << "Median after sorting: " << values[(int)N2/2-1] << std::endl;
        end=clock();
        time=(double)(end-begin)/CLOCKS_PER_SEC;
        std::cout << "Time: " << time << std::endl;  
    
        return 0;
    }
    
    double findmedian (double *values, double min, double max) {
        for (int i=0; i<N2; ++i) 
            if (values[i]>=min && values[i]<=max)
                return values[i];
    
        return 0;
    }
    
    int getindex (int *hist)
    {
        static int pd=0;
        int left=0;
        int right=0; 
        int i;
    
        for (int k=0; k<HISTN; k++)
            right+=hist[k];
    
        for (i=0; i<HISTN; i++) {
            right-=hist[i];
            if (i>0)
                left+=hist[i-1];
            if (hist[i]>0) {
                if (pd+right-left<=hist[i]) {
                    pd=pd+right-left;
                    break;
                }
            }
    
        }
    
        return i;
    }
    
    void put (int *hist, double min, double max, double val, double delta)
    {
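        int pos;
        //values outside the current interval are not counted
        if (val<min || val>max)
            return;
    
        pos=(val-min)/delta;
        //val==max would give pos==HISTN, one past the end: clamp to the last bin
        if (pos>=HISTN)
            pos=HISTN-1;
        hist[pos]++;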
        return;
    }
    

    I also included a naive calculation of the median (by sorting) in order to compare it with the result of the algorithm. 4 or 5 iterations are enough, which means we only need to read the set from the network or HDD 4-5 times.

    Some results:

    N2=10000
    HISTN=100
    
    Median with our algorithm: 0.497143 - 4 iterations of the algorithm
    Time: 0.000787
    Median after sorting: 0.497143
    Time: 0.001626
    
    (Algorithm is 2 times faster)
    
    N2=1000000
    HISTN=1000
    
    Median with our algorithm: 0.500665 - 4 iterations of the algorithm
    Time: 0.028874
    Median after sorting: 0.500665
    Time: 0.097498
    
    (Algorithm is ~3 times faster)
    

    If you want to parallelize the algorithm, each machine can hold N elements and calculate the histogram of its own data. Once it is calculated, each machine sends it to the master machine, which sums all the histograms (this is cheap, since a histogram can be really small... the algorithm even works with histograms of 2 intervals). The master then sends new instructions (i.e. the new interval) to the slave machines so that they calculate new histograms. Note that no machine needs any knowledge of the N elements the other machines own.
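
    A minimal sketch of that exchange, assuming half-open intervals [lo,hi) and equally wide bins. `Summary`, `summarize`, `mergeSummaries` and `binOfRank` are illustrative names, and the vectors stand in for each machine's local data and for the messages sent to the master. Unlike the implementation above, each machine here also reports how many of its values lie below the interval, so the master can work with global ranks directly.

    ```cpp
    #include <cassert>
    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // What one machine sends to the master for the interval [lo, hi).
    struct Summary {
        std::size_t below = 0;          // local values smaller than lo
        std::vector<std::size_t> bins;  // local counts per subinterval
    };

    // Runs on each machine over its own N elements.
    Summary summarize(const std::vector<double>& data, double lo, double hi,
                      std::size_t nbins) {
        assert(nbins > 0 && lo < hi);
        Summary s;
        s.bins.assign(nbins, 0);
        const double width = (hi - lo) / static_cast<double>(nbins);
        for (double v : data) {
            if (v < lo) { ++s.below; continue; }
            if (v >= hi) continue;
            std::size_t i = static_cast<std::size_t>((v - lo) / width);
            if (i >= nbins) i = nbins - 1;  // rounding guard near hi
            ++s.bins[i];
        }
        return s;
    }

    // Runs on the master: element-wise sum of all received summaries.
    Summary mergeSummaries(const std::vector<Summary>& parts) {
        Summary total;
        for (const Summary& p : parts) {
            if (total.bins.empty()) total.bins.assign(p.bins.size(), 0);
            assert(p.bins.size() == total.bins.size());
            total.below += p.below;
            for (std::size_t i = 0; i < p.bins.size(); ++i)
                total.bins[i] += p.bins[i];
        }
        return total;
    }

    // Index of the bin that holds the element of 0-based global rank `rank`.
    std::size_t binOfRank(const Summary& total, std::size_t rank) {
        std::size_t seen = total.below;
        if (rank < seen) throw std::out_of_range("rank lies below the interval");
        for (std::size_t i = 0; i < total.bins.size(); ++i) {
            seen += total.bins[i];
            if (rank < seen) return i;
        }
        throw std::out_of_range("rank lies above the interval");
    }
    ```

    The master would then broadcast the bounds of the chosen bin as the next [lo,hi) and repeat until that bin holds N values or fewer, at which point the remaining candidates fit on a single machine.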