
pyspark aggregation based on key and value pairs expanded across multiple columns


I am trying to find, using PySpark, the sum of the value columns grouped by the code columns in a table with the format below. In this example I have shown only 3 code columns and their respective value columns, but in a real scenario there can be up to 100.

Table A

id1  item  code_1  value_1  code_2  value_2  code_3  value_3
100  1     A       5        X       10       L       20
100  2     B       5        L       10       A       20

Expected output:

id1  item  sum_A  sum_X  sum_L  sum_B  Total
100  1     25     10     30     5      70
100  2     25     10     30     5      70
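
Here sum_A = 5 + 20 = 25 because code A appears in both rows, sum_L = 20 + 10 = 30, and Total = 25 + 10 + 30 + 5 = 70; the same sums are repeated on every row of a given id1.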

Can someone help me find a logic to accomplish this output?


Solution

  • You can use the stack, groupBy and pivot functions for this case.

    Example:

    from pyspark.sql.functions import expr, sum, concat, lit, col, first

    df.show()

    # unpivot the code/value column pairs into rows, then total the values per code
    df1 = df.select("id1", "item", expr("stack(3,code_1,value_1,code_2,value_2,code_3,value_3)")).\
        groupBy("id1", "col0").agg(sum("col1").alias("sum")).\
        withColumn("col0", concat(lit("sum_"), col("col0")))

    # re-attach every (id1, item) pair, then pivot the codes back into columns
    df.select("id1", "item").distinct().\
        join(df1, ["id1"]).\
        groupBy("id1", "item").\
        pivot("col0").\
        agg(first("sum")).\
        show()
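
    For the sample data, the intermediate df1 (the stacked key/value pairs aggregated per code, before the pivot) should look roughly like this, worked out by hand; row order may vary:

    +---+-----+---+
    |id1| col0|sum|
    +---+-----+---+
    |100|sum_A| 25|
    |100|sum_X| 10|
    |100|sum_L| 30|
    |100|sum_B|  5|
    +---+-----+---+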
    

    Output:

    #sample data
    +---+----+------+-------+------+-------+------+-------+
    |id1|item|code_1|value_1|code_2|value_2|code_3|value_3|
    +---+----+------+-------+------+-------+------+-------+
    |100|   1|     A|      5|     X|     10|     L|     20|
    |100|   2|     B|      5|     L|     10|     A|     20|
    +---+----+------+-------+------+-------+------+-------+
    
    #output    
    +---+----+-----+-----+-----+-----+
    |id1|item|sum_A|sum_B|sum_L|sum_X|
    +---+----+-----+-----+-----+-----+
    |100|   2|   25|    5|   30|   10|
    |100|   1|   25|    5|   30|   10|
    +---+----+-----+-----+-----+-----+
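
    The expected output also includes a Total column, which the pivot alone does not produce. A minimal sketch for adding it, assuming the pivoted result has been captured in a DataFrame named result (a hypothetical name; the code above calls .show() directly instead):

    from functools import reduce
    from operator import add
    from pyspark.sql.functions import col

    # row-wise total across all of the pivoted sum_ columns
    sum_cols = [c for c in result.columns if c.startswith("sum_")]
    result = result.withColumn("Total", reduce(add, [col(c) for c in sum_cols]))
    result.show()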
    

    UPDATE:

    Dynamic SQL (builds the stack expression from however many code_/value_ column pairs exist):

    # collect every code_/value_ column, in the order they appear in the schema
    req_cols = [c for c in df.columns if c.startswith("code_") or c.startswith("value_")]

    # stack(n, code_1, value_1, code_2, value_2, ...) where n is the number of pairs
    sql_expr = "stack(" + str(len(req_cols) // 2) + "," + ",".join(req_cols) + ")"

    df.show()

    df1 = df.select("id1", "item", expr(sql_expr)).\
        groupBy("id1", "col0").agg(sum("col1").alias("sum")).\
        withColumn("col0", concat(lit("sum_"), col("col0")))

    df.select("id1", "item").distinct().\
        join(df1, ["id1"]).\
        groupBy("id1", "item").\
        pivot("col0").\
        agg(first("sum")).\
        show()
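
    For the three-pair sample above, sql_expr evaluates to the same expression that was hard-coded earlier:

    stack(3,code_1,value_1,code_2,value_2,code_3,value_3)

    Note that this relies on each code_i column immediately preceding its matching value_i column in df.columns; if that ordering is not guaranteed, build req_cols pair by pair instead.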