Tags: pyspark, group-concat, persist

Persisting loop dataframes for group concat functions in Pyspark


I'm trying to aggregate a Spark dataframe down to one row per unique ID, selecting for each column the first non-null value for that ID according to a sort column. Basically I'm replicating MySQL's GROUP_CONCAT function.

The SO post Spark SQL replacement for MySQL's GROUP_CONCAT aggregate function was very helpful for replicating group_concat for a single column. I need to do this for a dynamic list of columns.
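For reference, a minimal sketch of the single-column version along the lines of that post, using the columns from my example data (the single_col name is just for illustration):

from pyspark.sql.functions import col, collect_list, concat_ws, slice, sort_array, struct

# Collect (score, first_name) structs per ID, sort descending by score,
# then keep the first entry's name field and render it as a plain string.
single_col = (
    df.groupBy("unique_id")
      .agg(sort_array(collect_list(struct("score", "first_name")), False).alias("sorted"))
      .withColumn("first_name", concat_ws(",", slice(col("sorted"), 1, 1).getField("first_name")))
      .drop("sorted")
)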

I would rather not copy this code for each column (a dozen or more, and possibly dynamic in the future), so I'm trying to implement it in a loop (frowned on in Spark, I know!) given a list of column names. The loop runs successfully, but the results of previous iterations don't persist, even when the intermediate df is cached/persisted (re: Cacheing and Loops in (Py)Spark).

Any help, pointers, or a more elegant non-looping solution would be appreciated (I'm not afraid to try a bit of Scala if a functional-programming approach is more suitable)! See also the loop-free sketch after the solution code below.

Given the following df:

unique_id  row_id   first_name  last_name  middle_name  score
1000000    1000002  Simmons     Bonnie     Darnell      88
1000000    1000006  Dowell      Crawford   Anne         87
1000000    1000007  NULL        Eric       Victor       89
1000000    1000000  Zachary     Fields     Narik        86
1000000    1000003  NULL        NULL       Warren       92
1000000    1000008  Paulette    Ronald     Irvin        85
from pyspark.sql.functions import col, collect_list, concat_ws, slice, sort_array, struct

group_column = "unique_id"
concat_list = ['first_name', 'last_name', 'middle_name']
sort_column = "score"
sort_order = False  # descending: highest score first

# Start from the distinct IDs and join one aggregated column per iteration.
df_final = df.select(group_column).distinct()
for i in concat_list:
    df_helper = (
        df.groupBy(group_column)
          # Pair each value with its score so sort_array orders by score.
          .agg(sort_array(collect_list(struct(sort_column, i)), sort_order).alias("collect_list"))
          # Pull the value field back out of the sorted array of structs.
          .withColumn("sorted_list", col("collect_list." + i))
          .withColumn("first_item", slice(col("sorted_list"), 1, 1))
          .withColumn(i, concat_ws(",", col("first_item")))
          .drop("collect_list", "sorted_list", "first_item")
    )
    print(i)
    df_final = df_final.join(df_helper, group_column, "inner")
    df_final.cache()

df_final.display()  # I'm using Databricks

My result looks like:

unique_id  middle_name
1000000    Warren

My desired result is:

unique_id  first_name  last_name  middle_name
1000000    Simmons     Eric       Warren



Solution

  • I found a solution to my own question: add a .collect() call on the dataframe as I join to it, rather than persist() or cache(); this produces the expected dataframe.

    from pyspark.sql.functions import col, collect_list, concat_ws, slice, sort_array, struct

    group_column = "unique_id"
    concat_list = ['first_name', 'last_name', 'middle_name']
    sort_column = "score"
    sort_order = False  # descending: highest score first

    df_final = df.select(group_column).distinct()
    for i in concat_list:
        df_helper = (
            df.groupBy(group_column)
              .agg(sort_array(collect_list(struct(sort_column, i)), sort_order).alias("collect_list"))
              .withColumn("sorted_list", col("collect_list." + i))
              .withColumn("first_item", slice(col("sorted_list"), 1, 1))
              .withColumn(i, concat_ws(",", col("first_item")))
              .drop("collect_list", "sorted_list", "first_item")
        )
        print(i)
        df_final = df_final.join(df_helper, group_column, "inner")
        df_final.collect()  # forces evaluation so each iteration's join sticks

    df_final.display()  # I'm using Databricks
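    As for the non-looping version I asked about: a loop-free sketch (untested at scale, mirroring the same struct-sort trick rather than being the canonical approach) would build every aggregate expression up front and run a single groupBy/agg, which sidesteps the loop-and-lineage issue entirely:

    from pyspark.sql.functions import collect_list, concat_ws, slice, sort_array, struct

    # One aggregate expression per target column: sort the (score, value)
    # structs, keep the first struct, and extract its value field.
    exprs = [
        concat_ws(
            ",",
            slice(sort_array(collect_list(struct(sort_column, c)), sort_order), 1, 1).getField(c),
        ).alias(c)
        for c in concat_list
    ]

    df_final = df.groupBy(group_column).agg(*exprs)

    If the loop version is kept, df_final = df_final.localCheckpoint() (Spark 2.3+) may be a cheaper way than collect() to force evaluation and truncate the lineage, since it doesn't pull every row back to the driver; treat that as an assumption to verify on your own cluster.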