Tags: python, dataframe, pyspark, median

Pyspark Dataframe - Median Without Numpy Or Other Libraries


I've been working on this in pyspark for a while and I'm stuck. I'm trying to get the median of the column numbers for its respective window, and I need to do this without using other libraries such as numpy.

So far (as depicted below), I've grouped the dataset into windows by the column id. The column row_number shows what each window looks like; there are three windows in this dataframe example.

This is what I would like:

I want each row to also contain the median of the column numbers over its id window, without taking its own row into consideration. The location of the median I require is given by my function below, called median_loc.

Example: For row_number = 5, I need to find the median of rows 1 to 4 above it (i.e. not including row_number 5). The median (per my requirement) is therefore the average of the column numbers in the same window where row_number = 1 and row_number = 2, i.e.:

Date        id      numbers row_number  med_loc
2017-03-02  group 1   98        1       [1]
2017-04-01  group 1   50        2       [1]
2018-03-02  group 1   5         3       [1, 2]
2016-03-01  group 2   49        1       [1]
2016-12-22  group 2   81        2       [1]
2017-12-31  group 2   91        3       [1, 2]
2018-08-08  group 2   19        4       [2]
2018-09-25  group 2   52        5       [1, 2]
2017-01-01  group 3   75        1       [1]
2018-12-12  group 3   17        2       [1]

The code I used to get the last column med_loc is as follows:

import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, IntegerType

def median_loc(sz):
    # sz is the size of the window above the current row, i.e. row_number - 1
    if sz == 1 or sz == 0:
        kth = [1]
        return kth
    elif sz % 2 == 0 and sz > 1:
        szh = sz // 2
        kth = [szh - 1, szh] if szh != 1 else [1, 2]
        return kth
    elif sz % 2 != 0 and sz > 1:
        kth = [(sz + 1) // 2]
        return kth


sqlContext.udf.register("median_location", median_loc)

# Give the udf an explicit array return type, and a new name so the plain
# Python function above stays callable
median_loc_udf = F.udf(median_loc, ArrayType(IntegerType()))

df = df.withColumn("med_loc", median_loc_udf(df.row_number - 1))
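
As a quick sanity check (plain Python, no Spark needed), calling the function with sz = row_number - 1 reproduces the med_loc column above:

for row_number in range(1, 6):
    print(row_number, median_loc(row_number - 1))
# 1 [1]
# 2 [1]
# 3 [1, 2]
# 4 [2]
# 5 [1, 2]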

Note: I only formatted med_loc as a list for easier understanding; it simply shows where the median is located in the respective window, for readers on Stack Overflow.

The output that I desire is as follows:

Date        id      numbers row_number  med_loc     median
2017-03-02  group 1   98        1       [1]           98
2017-04-01  group 1   50        2       [1]           98
2018-03-02  group 1   5         3       [1, 2]        74
2016-03-01  group 2   49        1       [1]           49
2016-12-22  group 2   81        2       [1]           49
2017-12-31  group 2   91        3       [1, 2]        65
2018-08-08  group 2   19        4       [2]           81
2018-09-25  group 2   52        5       [1, 2]        65
2017-01-01  group 3   75        1       [1]           75
2018-12-12  group 3   17        2       [1]           75

Basically, the way to get the median so far is something like this (sketched in plain Python below):

  1. If med_loc has one element (i.e. the list has just one entry, such as [1] or [3]), then median = df.numbers where df.row_number = df.med_loc

  2. If med_loc has two elements (i.e. the list has two entries, such as [1, 2] or [2, 3]), then median = average(df.numbers) where df.row_number in df.med_loc
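
In plain Python, the two rules above amount to something like this (just a sketch; window is a hypothetical dict mapping row_number to numbers within one id window):

def median_from_loc(window, med_loc):
    # Look up the numbers at the row_number positions listed in med_loc
    values = [window[rn] for rn in med_loc]
    # One element: the value itself; two elements: their average
    return sum(values) / len(values)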

I cannot stress enough how important it is for me not to use other libraries such as numpy to get the output. I have looked at other solutions that use np.median, and they work; however, that is not my requirement at this time.

I'm sorry if this explanation is long-winded and if I'm overcomplicating things. I've been looking at this for days and can't seem to figure it out. I also tried the percent_rank function, but I couldn't make it work because not all windows contain the 0.5 percentile.

Any help will be appreciated.


Solution

  • Suppose you start with the following DataFrame, df:

    +----------+-------+-------+
    |      Date|     id|numbers|
    +----------+-------+-------+
    |2017-03-02|group 1|     98|
    |2017-04-01|group 1|     50|
    |2018-03-02|group 1|      5|
    |2016-03-01|group 2|     49|
    |2016-12-22|group 2|     81|
    |2017-12-31|group 2|     91|
    |2018-08-08|group 2|     19|
    |2018-09-25|group 2|     52|
    |2017-01-01|group 3|     75|
    |2018-12-12|group 3|     17|
    +----------+-------+-------+
    

    Order DataFrame

    First add the row_number as you did in your example and assign the output to a new DataFrame df2:

    import pyspark.sql.functions as f
    from pyspark.sql import Window
    
    df2 = df.select(
        "*", f.row_number().over(Window.partitionBy("id").orderBy("Date")).alias("row_number")
    )
    df2.show()
    #+----------+-------+-------+----------+
    #|      Date|     id|numbers|row_number|
    #+----------+-------+-------+----------+
    #|2017-03-02|group 1|     98|         1|
    #|2017-04-01|group 1|     50|         2|
    #|2018-03-02|group 1|      5|         3|
    #|2016-03-01|group 2|     49|         1|
    #|2016-12-22|group 2|     81|         2|
    #|2017-12-31|group 2|     91|         3|
    #|2018-08-08|group 2|     19|         4|
    #|2018-09-25|group 2|     52|         5|
    #|2017-01-01|group 3|     75|         1|
    #|2018-12-12|group 3|     17|         2|
    #+----------+-------+-------+----------+
    

    Collect Values for Median

    Now you can join df2 to itself on the id column, with the condition that the left side's row_number is 1 or is greater than the right side's row_number. Then group by the left DataFrame's ("id", "Date", "row_number") and collect the right DataFrame's numbers into a list.

    When the row_number is equal to 1, we only want to keep the first element of this collected list. Otherwise we keep all the numbers, but sort them, because they need to be ordered to calculate the median.

    Call this intermediate DataFrame df3:

    df3 = df2.alias("l").join(df2.alias("r"), on="id", how="left")\
        .where("l.row_number = 1 OR (r.row_number < l.row_number)")\
        .groupBy("l.id", "l.Date", "l.row_number")\
        .agg(f.collect_list("r.numbers").alias("numbers"))\
        .select(
            "id",
            "Date",
            "row_number",
            f.when(
                f.col("row_number") == 1,
                f.array([f.col("numbers").getItem(0)])
            ).otherwise(f.sort_array("numbers")).alias("numbers")
        )
    df3.show()
    #+-------+----------+----------+----------------+
    #|     id|      Date|row_number|         numbers|
    #+-------+----------+----------+----------------+
    #|group 1|2017-03-02|         1|            [98]|
    #|group 1|2017-04-01|         2|            [98]|
    #|group 1|2018-03-02|         3|        [50, 98]|
    #|group 2|2016-03-01|         1|            [49]|
    #|group 2|2016-12-22|         2|            [49]|
    #|group 2|2017-12-31|         3|        [49, 81]|
    #|group 2|2018-08-08|         4|    [49, 81, 91]|
    #|group 2|2018-09-25|         5|[19, 49, 81, 91]|
    #|group 3|2017-01-01|         1|            [75]|
    #|group 3|2018-12-12|         2|            [75]|
    #+-------+----------+----------+----------------+
    

    Notice that the numbers column of df3 has a list of the appropriate values for which we want to find the median.

    Compute Median

    Since your version of Spark is greater than 2.1, you can use pyspark.sql.functions.posexplode() to compute the median from this list of values. For lower versions of Spark you would need a udf; a minimal sketch of one is below (the names median_of_sorted and median_udf are illustrative, not part of any existing API):
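
    from pyspark.sql.types import DoubleType

    def median_of_sorted(xs):
        # Median of an already-sorted list (like df3's numbers column)
        n = len(xs)
        mid = n // 2
        if n % 2 == 1:
            return float(xs[mid])
        return (xs[mid - 1] + xs[mid]) / 2.0

    median_udf = f.udf(median_of_sorted, DoubleType())
    # e.g. df3.select("id", "Date", median_udf("numbers").alias("median"))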

    For the posexplode() route, first create 2 helper columns in df3:

    • isEven: A Boolean to indicate if the numbers array has an even number of elements
    • middle: The index of the middle of the array, which is the floor of the length / 2

    After these columns are created, explode the array using posexplode(), which returns two new columns: pos and col. We then filter the resulting DataFrame to keep only the positions needed to compute the median.

    The logic on which positions to keep is as follows:

    • If isEven is False, we keep only the middle position
    • If isEven is True, we keep the middle position and the middle position - 1.

    Finally, group by id and Date and average the remaining numbers.

    df3.select(
        "*",
        f.when(
            (f.size("numbers") % 2) == 0,
            f.lit(True)
        ).otherwise(f.lit(False)).alias("isEven"),
        f.floor(f.size("numbers")/2).alias("middle")
    ).select(
            "id",
            "Date",
            "isEven",
            "middle",
            f.posexplode("numbers")
    ).where(
        "(isEven=False AND middle=pos) OR (isEven=True AND pos BETWEEN middle-1 AND middle)"
    ).groupBy("id", "Date").agg(f.avg("col").alias("median")).show()
    #+-------+----------+------+
    #|     id|      Date|median|
    #+-------+----------+------+
    #|group 1|2017-03-02|  98.0|
    #|group 1|2017-04-01|  98.0|
    #|group 1|2018-03-02|  74.0|
    #|group 2|2016-03-01|  49.0|
    #|group 2|2016-12-22|  49.0|
    #|group 2|2017-12-31|  65.0|
    #|group 2|2018-08-08|  81.0|
    #|group 2|2018-09-25|  65.0|
    #|group 3|2017-01-01|  75.0|
    #|group 3|2018-12-12|  75.0|
    #+-------+----------+------+
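
    Attach Median to Original Rows

    If you want the median next to the original columns (as in your desired output), one option is to capture the aggregation above in a variable instead of calling .show(), then join it back to df2 on id and Date. This is the same logic as above, just assigned to a variable:

    medians = df3.select(
        "*",
        f.when((f.size("numbers") % 2) == 0, f.lit(True)).otherwise(f.lit(False)).alias("isEven"),
        f.floor(f.size("numbers") / 2).alias("middle")
    ).select(
        "id", "Date", "isEven", "middle", f.posexplode("numbers")
    ).where(
        "(isEven=False AND middle=pos) OR (isEven=True AND pos BETWEEN middle-1 AND middle)"
    ).groupBy("id", "Date").agg(f.avg("col").alias("median"))

    # Join the medians back onto the row-numbered DataFrame
    df2.join(medians, on=["id", "Date"]).orderBy("id", "row_number").show()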