Tags: apache-spark, pyspark, apache-spark-sql

In PySpark, is it possible to group by and do an aggregation with a where condition?


I have the following dataframe df_A:

+-------------+-------------+--------+-----------+--------+
|           id|         Type|id_count|     Value1|  Value2|
+-------------+-------------+--------+-----------+--------+
|           18|          AAA|       2|       null|    null|
|           18|          BBB|       2|       null|    null|
|           16|          BBB|       2|       null|    null|
|           16|          CCC|       2|       null|    null|
|           17|          CCC|       1|       null|    null|
+-------------+-------------+--------+-----------+--------+

When id_count is 2 and Type is NOT equal to "AAA", I want to sum Value1 for that id and remove the duplicate rows, so that the result looks like this:

+-------------+-------------+--------+-----------+--------+
|           id|         Type|id_count|     Value1|  Value2|
+-------------+-------------+--------+-----------+--------+
|           18|          AAA|       2|       null|    null|
|           18|          BBB|       2|       null|    null|
|           16|      CCC+BBB|       2|       null|    null|
|           17|          CCC|       1|       null|    null|
+-------------+-------------+--------+-----------+--------+

To group by and sum, it would be straightforward: df_B = df_A.groupby('id').sum('Value1')

But I need to use filter, when, or where in combination with the groupby, and I can't find a way. How can this be done? Is there another approach?


Solution

  • You can filter the initial dataframe into 2 dataframes: call df1 the dataframe whose rows satisfy your condition (id_count is 2 and Type is NOT equal to "AAA") and df2 the dataframe with the remaining rows. Apply your groupBy on df1 and then union the 2 dataframes. Here's the code:

    from pyspark.sql import functions as F

    df = spark.createDataFrame([
        (18, "AAA", 2, 1, 5),
        (18, "BBB", 2, 2, 4),
        (16, "BBB", 2, 3, 3),
        (16, "CCC", 2, 4, 2),
        (17, "CCC", 1, 5, 1)
    ], ["id", "Type", "id_count", "Value1", "Value2"])

    # Rows to merge: id_count is 2 and Type is not "AAA"
    filter_cond = (df['id_count'] == 2) & (df['Type'] != 'AAA')
    df1 = df.filter(filter_cond)   # rows to aggregate per id
    df2 = df.filter(~filter_cond)  # rows to keep as-is

    df1.groupby("id").agg(
        F.concat_ws("+", F.collect_list("Type")).alias("Type"),
        F.first("id_count").alias("id_count"),
        F.sum("Value1").alias("Value1"),
        F.sum("Value2").alias("Value2")
    ).union(df2).show()
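
  • As a side note, here is a single-pass alternative sketch of my own (not part of the answer above, and the helper column grp_key is a hypothetical name): build a conditional grouping key with when(), so rows that don't satisfy the condition each get a unique key and pass through the groupBy unmerged.

    from pyspark.sql import functions as F

    # Rows matching the condition share one key per id and get merged;
    # non-matching rows each receive a unique key from
    # monotonically_increasing_id(), so they survive the groupBy unchanged.
    cond = (F.col("id_count") == 2) & (F.col("Type") != "AAA")
    result = (
        df.withColumn("grp_key",
                      F.when(cond, F.lit(-1).cast("long"))
                       .otherwise(F.monotonically_increasing_id()))
          .groupby("id", "grp_key")
          .agg(
              F.concat_ws("+", F.collect_list("Type")).alias("Type"),
              F.first("id_count").alias("id_count"),
              F.sum("Value1").alias("Value1"),
              F.sum("Value2").alias("Value2")
          )
          .drop("grp_key")
    )
    result.show()

    Note that in both approaches collect_list does not guarantee any ordering, so the concatenated Type for id 16 may come out as "BBB+CCC" or "CCC+BBB".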