
PySpark groupBy: flatten rows


I have been trying to flatten the rows of a PySpark DataFrame after a groupBy. My dataset looks like this:

+---------+---+------+----------+----+
|member_id|age|gender|      date|cost|
+---------+---+------+----------+----+
|        1| 35|  Male|2023-10-01| 200|
|        1| 35|  Male|2023-10-02| 210|
|        2| 28|Female|2023-10-01| 150|
|        2| 28|Female|2023-10-02| 160|
+---------+---+------+----------+----+

Now what I want as the output is:

[
    [
       [1, 35, Male, 2023-10-01, 200], [1, 35, Male, 2023-10-02, 210]
    ],
    [
       [2, 28, Female, 2023-10-01, 150], [2, 28, Female, 2023-10-02, 160]
    ]
]

I have tried to achieve this but have not been able to.


Solution

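  • Simple idea:

    First combine all columns of each row into a single array column with F.array, then apply F.collect_list to that column in a groupBy. Because an array column holds a single element type, the mixed columns end up as strings.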

    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    # SparkSession is the standard entry point (SQLContext is deprecated since Spark 3.0)
    spark = SparkSession.builder.master("local").getOrCreate()

    data1 = [
        [1, 35, "Male", "2023-10-01", 200],
        [1, 35, "Male", "2023-10-02", 210],
        [2, 28, "Female", "2023-10-01", 150],
        [2, 28, "Female", "2023-10-02", 160],
    ]

    columns = ["member_id", "age", "gender", "date", "cost"]

    df1 = spark.createDataFrame(data=data1, schema=columns)

    df1.show(n=10, truncate=False)
    print("Collect columns into list")
    # Array elements must share one type, so cast every column to string explicitly
    df2 = df1.withColumn("colList", F.array(*[F.col(c).cast("string") for c in df1.columns]))
    df2.show(n=10, truncate=False)
    print("collect_list on the groupBy")
    # One output row per member_id, holding the list of that member's row arrays
    df3 = df2.groupBy("member_id").agg(F.collect_list(F.col("colList")))
    df3.show(n=10, truncate=False)
    

    Output :

    +---------+---+------+----------+----+
    |member_id|age|gender|date      |cost|
    +---------+---+------+----------+----+
    |1        |35 |Male  |2023-10-01|200 |
    |1        |35 |Male  |2023-10-02|210 |
    |2        |28 |Female|2023-10-01|150 |
    |2        |28 |Female|2023-10-02|160 |
    +---------+---+------+----------+----+
    
    Collect columns into list
    +---------+---+------+----------+----+--------------------------------+
    |member_id|age|gender|date      |cost|colList                         |
    +---------+---+------+----------+----+--------------------------------+
    |1        |35 |Male  |2023-10-01|200 |[1, 35, Male, 2023-10-01, 200]  |
    |1        |35 |Male  |2023-10-02|210 |[1, 35, Male, 2023-10-02, 210]  |
    |2        |28 |Female|2023-10-01|150 |[2, 28, Female, 2023-10-01, 150]|
    |2        |28 |Female|2023-10-02|160 |[2, 28, Female, 2023-10-02, 160]|
    +---------+---+------+----------+----+--------------------------------+
    
    collect_list on the groupBy
    +---------+--------------------------------------------------------------------+
    |member_id|collect_list(colList)                                               |
    +---------+--------------------------------------------------------------------+
    |1        |[[1, 35, Male, 2023-10-01, 200], [1, 35, Male, 2023-10-02, 210]]    |
    |2        |[[2, 28, Female, 2023-10-01, 150], [2, 28, Female, 2023-10-02, 160]]|
    +---------+--------------------------------------------------------------------+