python, apache-spark, pyspark

collect_list by preserving order based on another variable


I am trying to create a new column of lists in PySpark using a groupby aggregation on an existing set of columns. An example input data frame is provided below:

------------------------
id | date        | value
------------------------
1  |2014-01-03   | 10 
1  |2014-01-04   | 5
1  |2014-01-05   | 15
1  |2014-01-06   | 20
2  |2014-02-10   | 100   
2  |2014-03-11   | 500
2  |2014-04-15   | 1500

The expected output is:

id | value_list
------------------------
1  | [10, 5, 15, 20]
2  | [100, 500, 1500]

The values within a list are sorted by the date.

I tried using collect_list as follows:

from pyspark.sql import functions as F
ordered_df = input_df.orderBy(['id','date'],ascending = True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

But collect_list doesn't guarantee order even if I sort the input data frame by date before aggregation.

Could someone help with how to do the aggregation while preserving the order based on a second (date) variable?


Solution

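  • EDIT: pyspark.sql.functions.array_sort was added in PySpark 2.4. Applied to the collected (date, value) structs, it sorts them by date just as the sorter UDF defined below does, and it will generally be more performant because it is a built-in function rather than a Python UDF. You still need to select the value field from the sorted structs afterwards. Leaving the old answer for posterity.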

    For PySpark <2.4 only:

    If you collect both dates and values as a list, you can sort the resulting column according to date using a UDF, and then keep only the values in the result.

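    import operator
    import pyspark.sql.functions as F
    from pyspark.sql.types import ArrayType, IntegerType
    
    # create list column: one (date, value) struct per row, collected per id
    grouped_df = input_df.groupby("id") \
                   .agg(F.collect_list(F.struct("date", "value")) \
                   .alias("list_col"))
    
    # define udf: sort the structs by date (field 0), then keep only the value (field 1)
    def sorter(l):
      res = sorted(l, key=operator.itemgetter(0))
      return [item[1] for item in res]
    
    # declare the return type; without it F.udf defaults to StringType
    # (IntegerType assumes "value" is an integer column)
    sort_udf = F.udf(sorter, ArrayType(IntegerType()))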
    
    # test
    grouped_df.select("id", sort_udf("list_col") \
      .alias("sorted_list")) \
      .show(truncate = False)
    +---+----------------+
    |id |sorted_list     |
    +---+----------------+
    |1  |[10, 5, 15, 20] |
    |2  |[100, 500, 1500]|
    +---+----------------+