I am trying to create a new column of lists in PySpark using a groupby aggregation on an existing set of columns. An example input data frame is provided below:
id | date       | value
1  | 2014-01-03 | 10
1  | 2014-01-04 | 5
1  | 2014-01-05 | 15
1  | 2014-01-06 | 20
2  | 2014-02-10 | 100
2  | 2014-03-11 | 500
2  | 2014-04-15 | 1500
The expected output is:
id | value_list
1  | [10, 5, 15, 20]
2  | [100, 500, 1500]
The values within a list are sorted by the date.
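To pin down the intended result independently of Spark, here is a plain-Python reference sketch of the same transformation. The function name expected_value_lists is hypothetical, and it assumes dates are ISO yyyy-MM-dd strings, so string order matches date order.

```python
from itertools import groupby
from operator import itemgetter

# Sample rows from the table above: (id, date, value)
rows = [
    (1, "2014-01-03", 10), (1, "2014-01-04", 5), (1, "2014-01-05", 15),
    (1, "2014-01-06", 20), (2, "2014-02-10", 100), (2, "2014-03-11", 500),
    (2, "2014-04-15", 1500),
]

def expected_value_lists(rows):
    """Group (id, date, value) rows by id; each list is ordered by date."""
    # groupby needs input sorted by the grouping key; sorting on (id, date)
    # also puts the rows inside each group in date order.
    ordered = sorted(rows, key=itemgetter(0, 1))
    return {
        key: [value for _, _, value in group]
        for key, group in groupby(ordered, key=itemgetter(0))
    }

print(expected_value_lists(rows))  # {1: [10, 5, 15, 20], 2: [100, 500, 1500]}
```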
I tried using collect_list as follows:
from pyspark.sql import functions as F
ordered_df = input_df.orderBy(["id", "date"], ascending=True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))
But collect_list doesn't guarantee order even if I sort the input data frame by date before aggregation.
Could someone explain how to do this aggregation while preserving the order given by a second (date) column?
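EDIT: pyspark.sql.functions.array_sort was added in PySpark 2.4. Applied to the collected array of (date, value) structs, it sorts them by date (structs are compared field by field, and date is the first field), after which the values can be extracted with .value. This replaces the sorter UDF defined below and will generally be more performant, since it avoids shipping rows to a Python worker. Leaving the old answer for posterity.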
For PySpark <2.4 only:
If you collect both dates and values as a list, you can sort the resulting column by date using a UDF, and then keep only the values in the result.
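import operator
import pyspark.sql.functions as F
import pyspark.sql.types as T

# create a list column of (date, value) structs
grouped_df = input_df.groupby("id") \
    .agg(F.collect_list(F.struct("date", "value")).alias("list_col"))

# define a UDF that sorts the structs by date and keeps only the values
def sorter(l):
    res = sorted(l, key=operator.itemgetter(0))
    return [item[1] for item in res]

# declare the return type (assumes integer values); without it the UDF returns a string
sort_udf = F.udf(sorter, T.ArrayType(T.LongType()))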
# test
grouped_df.select("id", sort_udf("list_col") \
.alias("sorted_list")) \
.show(truncate = False)
+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+
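Because sorter is ordinary Python, it can be checked without starting Spark. Spark passes the UDF a list of Row objects, and since Row is a tuple subclass, plain (date, value) tuples are a reasonable stand-in. The shuffled pairs below are just the question's id 1 rows in a different order.

```python
import operator

def sorter(l):
    # sort (date, value) pairs by date, then keep only the values
    res = sorted(l, key=operator.itemgetter(0))
    return [item[1] for item in res]

# tuples stand in for the Row(date, value) structs Spark would pass in
pairs = [("2014-01-05", 15), ("2014-01-03", 10), ("2014-01-06", 20), ("2014-01-04", 5)]
print(sorter(pairs))  # [10, 5, 15, 20]
```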