Tags: pyspark, aggregate-functions

How to use a window function instead of aggregation using groupBy followed by a join?


I have a dataframe (df) with the following structure: Grid_ID, Latitude, Longitude, DateTimeStamp

Input:

Grid_ID    Latitude Longitude DateTimeStamp 
Grid_1     Lat1     Long1     2021-06-30 00:00:00
Grid_1     Lat1     Long1     2021-06-30 00:01:00
Grid_1     Lat1     Long1     2021-06-30 00:02:00
Grid_1     Lat2     Long2     2021-07-01 00:00:00
Grid_1     Lat2     Long2     2021-07-01 00:01:00
Grid_1     Lat2     Long2     2021-07-01 00:02:00

Each Grid_ID has two sets of latitude/longitude values that are mutually exclusive based on the date: when Date <= 2021-06-30, Latitude/Longitude = Lat1/Long1, and when Date > 2021-06-30, Latitude/Longitude = Lat2/Long2.

I need to create new columns (Corrected_Lat and Corrected_Long) and assign Lat2/Long2 (the latest coordinates) to every row.
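For context, here is a minimal sketch that reproduces the sample data above (it assumes a running SparkSession; the "Lat1"/"Long1" strings are placeholders for real float coordinates):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Placeholder data mirroring the input table; real data would have
# float coordinates instead of the "Lat1"/"Long1" labels.
df = spark.createDataFrame(
    [
        ("Grid_1", "Lat1", "Long1", "2021-06-30 00:00:00"),
        ("Grid_1", "Lat1", "Long1", "2021-06-30 00:01:00"),
        ("Grid_1", "Lat1", "Long1", "2021-06-30 00:02:00"),
        ("Grid_1", "Lat2", "Long2", "2021-07-01 00:00:00"),
        ("Grid_1", "Lat2", "Long2", "2021-07-01 00:01:00"),
        ("Grid_1", "Lat2", "Long2", "2021-07-01 00:02:00"),
    ],
    ["Grid_ID", "Latitude", "Longitude", "DateTimeStamp"],
)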

I am using groupBy and agg as below to do this:

import pyspark.sql.functions as F

# Keep only rows from the later date; each Grid_ID then has exactly one
# Latitude/Longitude value, which becomes the corrected coordinate pair.
df_dated = df.withColumn("date", F.to_date("DateTimeStamp")) \
             .filter(F.col("date") == "2021-07-01") \
             .groupBy("Grid_ID") \
             .agg(F.collect_set("Latitude").getItem(0).cast("float").alias("corrected_lat"),
                  F.collect_set("Longitude").getItem(0).cast("float").alias("corrected_long")) \
             .withColumnRenamed("Grid_ID", "Grid_ID_dated") \
             .select("Grid_ID_dated", "corrected_lat", "corrected_long")

# Join the corrected coordinates back onto every original row.
df_final = df.join(df_dated, on=df.Grid_ID == df_dated.Grid_ID_dated, how="inner") \
             .select(*df.columns, "corrected_lat", "corrected_long")

Output:

Grid_ID    Latitude Longitude DateTimeStamp        corrected_lat   corrected_long
Grid_1     Lat1     Long1     2021-06-30 00:00:00  Lat2            Long2
Grid_1     Lat1     Long1     2021-06-30 00:01:00  Lat2            Long2
Grid_1     Lat1     Long1     2021-06-30 00:02:00  Lat2            Long2
Grid_1     Lat2     Long2     2021-07-01 00:00:00  Lat2            Long2
Grid_1     Lat2     Long2     2021-07-01 00:01:00  Lat2            Long2
Grid_1     Lat2     Long2     2021-07-01 00:02:00  Lat2            Long2

But I am wondering whether a window function can be used here, and whether it would be faster than the first approach using groupBy and agg.

Any other approach that is faster is certainly appreciated.


Solution

  • This will apply the latest Latitude/Longitude value to all rows (per Grid_ID):

    import pyspark.sql.functions as F
    from pyspark.sql.window import Window

    # Partition per Grid_ID, order by time, and span the whole partition so
    # that F.last picks the most recent value for every row.
    w = (
        Window.partitionBy("Grid_ID")
        .orderBy("DateTimeStamp")
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    )
    df_final = (
        df.select(
            *df.columns,
            F.last("Latitude").over(w).alias("corrected_lat"),
            F.last("Longitude").over(w).alias("corrected_long"),
        )
    )
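
  • On the question of speed: both approaches shuffle the data by Grid_ID, but the window version skips building and joining a second DataFrame, so it generally involves fewer shuffles. One way to check on your own data is to compare the physical plans and count the Exchange (shuffle) operators:

    df_final.explain()  # each Exchange node in the plan marks a shuffle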