Tags: python, pyspark, apache-spark-sql

How to groupBy on two columns and work out total and average values for each group using PySpark


I have the following DataFrame and, using PySpark, I'm trying to compute the following:

  1. Total Fare by Pick
  2. Total Tip by Pick
  3. Avg Drag by Pick
  4. Avg Drag by Drop

Pick  Drop  Fare   Tip    Drag
1     1     4.00   4.00   1.00
1     2     5.00   10.00  8.00
1     2     5.00   15.00  12.00
3     2     11.00  12.00  17.00
3     5     41.00  25.00  13.00
4     6     50.00  70.00  2.00

My query so far looks like this:

from pyspark.sql import functions as func
from pyspark.sql.functions import desc

data = [
    (1, 1, 4.00, 4.00, 1.00),
    (1, 2, 5.00, 10.00, 8.00),
    (1, 2, 5.00, 15.00, 12.00),
    (3, 2, 11.00, 12.00, 17.00),
    (3, 5, 41.00, 25.00, 13.00),
    (4, 6, 50.00, 70.00, 2.00)
]

columns = ["Pick", "Drop", "Fare", "Tip", "Drag"]
df = spark.createDataFrame(data, columns)


df.groupBy('Pick', 'Drop') \
    .agg(
        func.sum('Fare').alias('FarePick'),
        func.sum('Tip').alias('TipPick'),
        func.avg('Drag').alias('AvgDragPick'),
        func.avg('Drag').alias('AvgDragDrop')) \
    .orderBy('Pick').show()

However, I don't think this is correct. I'm a bit stuck on (4), because grouping by both columns does not seem right for an average per Drop. Can anyone suggest a correction here? The output needs to be in one (1) table, such as:

Pick Drop FarePick TipPick AvgDragPick AvgDragDrop

Solution

  • If you want all of the columns together in a single table, just use Window functions.

    from pyspark.sql import functions as f
    from pyspark.sql import Window
    
    data = [
        (1, 1, 4.00, 4.00, 1.00),
        (1, 2, 5.00, 10.00, 8.00),
        (1, 2, 5.00, 15.00, 12.00),
        (3, 2, 11.00, 12.00, 17.00),
        (3, 5, 41.00, 25.00, 13.00),
        (4, 6, 50.00, 70.00, 2.00)
    ]
    
    columns = ["Pick", "Drop", "Fare", "Tip", "Drag"]
    df = spark.createDataFrame(data, columns)
    
    # Each window function computes its aggregate over the given partition
    # while keeping every input row, so all results end up in one table
    df_new = (
        df
        .withColumn("TotalFarePick", f.sum("Fare").over(Window.partitionBy("Pick")))
        .withColumn("TotalTipPick", f.sum("Tip").over(Window.partitionBy("Pick")))
        .withColumn("AvgDragPick", f.avg("Drag").over(Window.partitionBy("Pick")))
        .withColumn("AvgDragDrop", f.avg("Drag").over(Window.partitionBy("Drop")))
        .drop("Fare", "Tip", "Drag")
    )
    
    df_new.show()
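    # For reference, the per-row values this produces on the sample data
    # (row order from show() may vary):
    # Pick Drop TotalFarePick TotalTipPick AvgDragPick AvgDragDrop
    # 1    1    14.0          29.0         7.0         1.0
    # 1    2    14.0          29.0         7.0         12.333333333333334
    # 1    2    14.0          29.0         7.0         12.333333333333334
    # 3    2    52.0          37.0         15.0        12.333333333333334
    # 3    5    52.0          37.0         15.0        13.0
    # 4    6    50.0          70.0         2.0         2.0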
    
    
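    For comparison, the same table can also be built with two groupBy aggregations joined back onto the rows. This is just a sketch, assuming the df and the f import from the snippet above; note the column order of the result differs from the window version:

    pick_agg = df.groupBy("Pick").agg(
        f.sum("Fare").alias("TotalFarePick"),
        f.sum("Tip").alias("TotalTipPick"),
        f.avg("Drag").alias("AvgDragPick"),
    )
    drop_agg = df.groupBy("Drop").agg(f.avg("Drag").alias("AvgDragDrop"))

    # Joining on the key columns attaches the aggregates to every row
    df_joined = (
        df.select("Pick", "Drop")
        .join(pick_agg, "Pick")
        .join(drop_agg, "Drop")
    )
    df_joined.show()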

    Also, prefer not to use \ for line continuation; PEP 8 discourages backslashes in favour of implied continuation inside parentheses.

    See https://peps.python.org/pep-0008/:

    The preferred way of wrapping long lines is by using Python’s implied line continuation inside parentheses, brackets and braces. Long lines can be broken over multiple lines by wrapping expressions in parentheses. These should be used in preference to using a backslash for line continuation.
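
    For illustration, here is the groupBy query from the question rewritten with implied line continuation instead of backslashes (using the df and func names defined in the question, with the duplicate avg alias dropped):

    result = (
        df.groupBy("Pick", "Drop")
        .agg(
            func.sum("Fare").alias("FarePick"),
            func.sum("Tip").alias("TipPick"),
            func.avg("Drag").alias("AvgDragPick"),
        )
        .orderBy("Pick")
    )
    result.show()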