Tags: python, pyspark, spark-streaming, rdd, dstream

How to calculate average by category in pyspark streaming?


I have CSV data coming in as DStreams from traffic counters. A sample is as follows:

    Location,Vehicle,Speed,
    tracker1,car,57,
    tracker1,car,90,
    tracker1,mbike,81,
    tracker1,mbike,65,
    tracker2,car,69,
    tracker2,car,34,
    tracker2,mbike,29,
    tracker2,mbike,76,

I want to calculate average speed (for each location) by vehicle category.

I want to achieve this with transformations. Below is the result I am looking for.

Location  |  Car | MBike
Tracker 1 | 73.5 |  73
Tracker 2 | 51.5 |  52.5
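For intuition, the per-category average reduces to keeping a (sum, count) pair per (Location, Vehicle) key and dividing at the end. This is the same pattern a DStream pipeline would apply per batch, after mapping each CSV line to `((location, vehicle), (speed, 1))` and combining pairs with `reduceByKey`. A minimal plain-Python sketch of that aggregation on the sample data:

```python
# (sum, count) accumulation per (Location, Vehicle) key -- the same
# logic a DStream map/reduceByKey pipeline would apply to each batch.
rows = [
    ("tracker1", "car", 57), ("tracker1", "car", 90),
    ("tracker1", "mbike", 81), ("tracker1", "mbike", 65),
    ("tracker2", "car", 69), ("tracker2", "car", 34),
    ("tracker2", "mbike", 29), ("tracker2", "mbike", 76),
]

acc = {}  # (location, vehicle) -> (speed_sum, count)
for location, vehicle, speed in rows:
    s, c = acc.get((location, vehicle), (0, 0))
    acc[(location, vehicle)] = (s + speed, c + 1)

# divide once per key to get the averages
averages = {key: s / c for key, (s, c) in acc.items()}
print(averages)
# {('tracker1', 'car'): 73.5, ('tracker1', 'mbike'): 73.0, ('tracker2', 'car'): 51.5, ('tracker2', 'mbike'): 52.5}
```

Keeping (sum, count) pairs rather than running averages is what makes the reduction associative, which is what `reduceByKey` requires.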

Solution

  • I'm not sure exactly what you want, but if it's average speed by vehicle, by location, then you can use a Window function:

    # assumes an active SparkSession named `spark` (e.g. the pyspark shell)
    df = spark.createDataFrame(
        [
         ('tracker1','car',57)
        ,('tracker1','car',90)
        ,('tracker1','mbike',81)
        ,('tracker1','mbike',65)
        ,('tracker2','car',69)
        ,('tracker2','car',34)
        ,('tracker2','mbike',29)
        ,('tracker2','mbike',76)
        ],
        ['Location','Vehicle','Speed']
    )
    
    from pyspark.sql import Window
    import pyspark.sql.functions as F
    
    # average within each (Location, Vehicle) group
    w = Window.partitionBy("Location","Vehicle")
    
    # attach the per-(Location, Vehicle) average to every row, then
    # pivot the Vehicle categories into columns
    df_pivot = df\
                .withColumn('avg_speed', F.avg(F.col('Speed')).over(w))\
                .groupby('Location','Vehicle', 'avg_speed')\
                .pivot("Vehicle")\
                .agg(F.first('avg_speed'))\
                .drop('Vehicle', 'avg_speed')
    
    # build an aggregation spec for every column except the first (Location);
    # compare with `!=`, since `is not` checks identity, not equality
    expr = {x: "sum" for x in df_pivot.columns if x != df_pivot.columns[0]}
    
    print(expr)
    
    df_almost_final = df_pivot\
                        .groupBy("Location")\
                        .agg(expr)\
                        .orderBy('Location')
    
    # strip the generated `sum(...)` wrapper from the pivoted column names
    df_final = df_almost_final.select(
        [F.col(c).alias(c.replace('sum(','').replace(')',''))
         for c in df_almost_final.columns]
    )
    
    
    df_final.show()
    
    
    
    # +--------+-----+----+
    # |Location|mbike| car|
    # +--------+-----+----+
    # |tracker1| 73.0|73.5|
    # |tracker2| 52.5|51.5|
    # +--------+-----+----+
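Note that if the rows are available as a batch DataFrame, the whole window/pivot/rename chain can be collapsed into one step, since `pivot` composes directly with an aggregate: `df.groupBy('Location').pivot('Vehicle').agg(F.avg('Speed'))`. The pivot itself just reshapes long-format (Location, Vehicle, average) entries into one row per location with one column per vehicle category; a plain-Python sketch of that reshape, using the averages from the sample data:

```python
# long format: (location, vehicle) -> average speed, as computed above
long_avgs = {
    ("tracker1", "car"): 73.5, ("tracker1", "mbike"): 73.0,
    ("tracker2", "car"): 51.5, ("tracker2", "mbike"): 52.5,
}

# wide format: one row per location, one column per vehicle category --
# the same reshape `pivot("Vehicle")` performs
wide = {}
for (location, vehicle), avg in long_avgs.items():
    wide.setdefault(location, {})[vehicle] = avg

print(wide)
# {'tracker1': {'car': 73.5, 'mbike': 73.0}, 'tracker2': {'car': 51.5, 'mbike': 52.5}}
```

This is purely a reshape: no values change, only the layout, which is why the solution above can fill the pivoted cells with `F.first('avg_speed')`.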