I am trying to compute the sum of the value columns grouped by the code columns, for a table in the format below, using PySpark. This example has only 3 code/value column pairs, but in the real scenario there can be up to 100.
Table A
id1 | item | code_1 | value_1 | code_2 | value_2 | code_3 | value_3 |
---|---|---|---|---|---|---|---|
100 | 1 | A | 5 | X | 10 | L | 20 |
100 | 2 | B | 5 | L | 10 | A | 20 |
Expected output:
id1 | item | sum_A | sum_X | sum_L | sum_B | Total |
---|---|---|---|---|---|---|
100 | 1 | 25 | 10 | 30 | 5 | 70 |
100 | 2 | 25 | 10 | 30 | 5 | 70 |
Can someone help me find the logic to produce this output?
You can use the stack + groupBy + pivot functions for this case.
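To make the three steps concrete before the Spark version, here is a minimal plain-Python sketch of the same logic on the two sample rows. It does not use Spark at all; the names rows, N_PAIRS, long_rows, sums and result are made up for this illustration.

```python
from collections import defaultdict

# Sample rows from Table A as plain dicts (illustration only, no Spark).
rows = [
    {"id1": 100, "item": 1, "code_1": "A", "value_1": 5,
     "code_2": "X", "value_2": 10, "code_3": "L", "value_3": 20},
    {"id1": 100, "item": 2, "code_1": "B", "value_1": 5,
     "code_2": "L", "value_2": 10, "code_3": "A", "value_3": 20},
]
N_PAIRS = 3  # number of code/value column pairs

# Step 1 (stack): one (id1, code, value) record per code/value pair.
long_rows = [
    (r["id1"], r[f"code_{i}"], r[f"value_{i}"])
    for r in rows
    for i in range(1, N_PAIRS + 1)
]

# Step 2 (groupBy + sum): total value per (id1, code).
sums = defaultdict(int)
for id1, code, value in long_rows:
    sums[(id1, code)] += value

# Step 3 (join + pivot): attach every per-code sum of an id1 to each of its items.
result = []
for r in rows:
    wide = {"id1": r["id1"], "item": r["item"]}
    for (id1, code), total in sorted(sums.items()):
        if id1 == r["id1"]:
            wide[f"sum_{code}"] = total
    result.append(wide)

for row in result:
    print(row)
```

Both items end up with the same sums because the aggregation is per id1, which is what the expected output in the question asks for.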
Example:
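from pyspark.sql.functions import col, concat, expr, first, lit, sum  # this sum shadows Python's built-in

df.show()
# Unpivot the code/value pairs into rows (col0 = code, col1 = value), then sum per id1 and code.
df1 = df.select("id1", "item", expr("stack(3, code_1, value_1, code_2, value_2, code_3, value_3)")).\
    groupBy("id1", "col0").agg(sum("col1").alias("sum")).\
    withColumn("col0", concat(lit("sum_"), col("col0")))
# Join the per-code sums back to every (id1, item) and pivot the codes into columns.
df.select("id1", "item").distinct().\
    join(df1, ["id1"]).\
    groupBy("id1", "item").\
    pivot("col0").\
    agg(first("sum")).\
    show()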
Output:
#sample data
+---+----+------+-------+------+-------+------+-------+
|id1|item|code_1|value_1|code_2|value_2|code_3|value_3|
+---+----+------+-------+------+-------+------+-------+
|100|   1|     A|      5|     X|     10|     L|     20|
|100|   2|     B|      5|     L|     10|     A|     20|
+---+----+------+-------+------+-------+------+-------+
#output
+---+----+-----+-----+-----+-----+
|id1|item|sum_A|sum_B|sum_L|sum_X|
+---+----+-----+-----+-----+-----+
|100|   2|   25|    5|   30|   10|
|100|   1|   25|    5|   30|   10|
+---+----+-----+-----+-----+-----+
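The output above does not include the Total column from the question. One possible way to add it is to sum the pivoted sum_ columns row by row. The sketch below only builds the SQL expression string for that; build_total_expr and the variable name pivoted are hypothetical, and it assumes every pivoted column starts with sum_.

```python
def build_total_expr(columns, prefix="sum_"):
    """Build a SQL expression that adds up all columns starting with `prefix`.

    coalesce(..., 0) keeps a missing code (NULL after the pivot) from
    turning the whole total into NULL.
    """
    sum_cols = [c for c in columns if c.startswith(prefix)]
    if not sum_cols:
        raise ValueError(f"no columns start with {prefix!r}")
    return " + ".join(f"coalesce(`{c}`, 0)" for c in sum_cols)


total_expr = build_total_expr(["id1", "item", "sum_A", "sum_B", "sum_L", "sum_X"])
print(total_expr)

# Hypothetical usage, assuming `pivoted` holds the pivoted DataFrame
# (the chain above without the final .show()):
#   pivoted.withColumn("Total", expr(total_expr)).show()
```

For the sample data this adds 25 + 5 + 30 + 10, which gives the 70 shown in the expected output.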
UPDATE:
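Dynamic SQL (builds the stack expression from the column names, so it works for any number of code/value pairs):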
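from pyspark.sql.functions import col, concat, expr, first, lit, sum  # this sum shadows Python's built-in

# Collect the code_*/value_* columns; this relies on df.columns listing each code_N right before its value_N.
req_cols = [c for c in df.columns if c.startswith(("code_", "value_"))]
sql_expr = "stack(" + str(len(req_cols) // 2) + "," + ",".join(req_cols) + ")"
df.show()
# Unpivot, then sum per id1 and code (col0 = code, col1 = value).
df1 = df.select("id1", "item", expr(sql_expr)).\
    groupBy("id1", "col0").agg(sum("col1").alias("sum")).\
    withColumn("col0", concat(lit("sum_"), col("col0")))
# Join the sums back to every (id1, item) and pivot the codes into columns.
df.select("id1", "item").distinct().\
    join(df1, ["id1"]).\
    groupBy("id1", "item").\
    pivot("col0").\
    agg(first("sum")).\
    show()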
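The req_cols list above only produces a correct stack expression when the DataFrame lists each code column immediately before its value column. If the column order is not guaranteed, a slightly more defensive sketch pairs the columns by their numeric suffix instead. build_stack_expr is a hypothetical helper, and it assumes the columns are named code_N and value_N (matched case-insensitively).

```python
import re


def build_stack_expr(columns):
    """Return a Spark SQL stack(...) expression pairing code_N with value_N."""
    codes, values = {}, {}
    for name in columns:
        match = re.fullmatch(r"(code|value)_(\d+)", name, flags=re.IGNORECASE)
        if match is None:
            continue  # not a code/value column, e.g. id1 or item
        target = codes if match.group(1).lower() == "code" else values
        target[int(match.group(2))] = name
    if not codes or codes.keys() != values.keys():
        raise ValueError("every code_N column needs a matching value_N column")
    # Emit the pairs in numeric order, regardless of the input column order.
    pairs = [f"`{codes[n]}`, `{values[n]}`" for n in sorted(codes)]
    return f"stack({len(pairs)}, {', '.join(pairs)})"


print(build_stack_expr(["id1", "item", "value_2", "code_1", "code_2", "value_1"]))
```

The returned string can be passed to expr() in place of sql_expr; the backticks protect column names that contain unusual characters.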