I have a DataFrame (df) with the following structure:
Grid_ID, Latitude, Longitude, DateTimeStamp
Input:
Grid_ID Latitude Longitude DateTimeStamp
Grid_1 Lat1 Long1 2021-06-30 00:00:00
Grid_1 Lat1 Long1 2021-06-30 00:01:00
Grid_1 Lat1 Long1 2021-06-30 00:02:00
Grid_1 Lat2 Long2 2021-07-01 00:00:00
Grid_1 Lat2 Long2 2021-07-01 00:01:00
Grid_1 Lat2 Long2 2021-07-01 00:02:00
Each Grid_ID has two mutually exclusive sets of latitude/longitude values, depending on the date: when the date is on or before 2021-06-30, Latitude/Longitude = Lat1/Long1, and when it is after 2021-06-30, Latitude/Longitude = Lat2/Long2.
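The cutoff is applied to the calendar date, not the full timestamp, so any time on 2021-06-30 still maps to Lat1/Long1. Below is a minimal plain-Python sketch of that rule, only to make the comparison explicit. The function name `location_period` and the returned labels are hypothetical.

```python
from datetime import date, datetime

# Cutoff from the rule above: dates on or before this day use Lat1/Long1
CUTOFF = date(2021, 6, 30)


def location_period(ts: str) -> str:
    """Hypothetical helper: which lat/long pair applies to a 'YYYY-MM-DD HH:MM:SS' timestamp."""
    # Compare on the date part only, mirroring F.to_date("DateTimeStamp")
    d = datetime.strptime(ts, "%Y-%m-%d %H:%M:%S").date()
    return "Lat1/Long1" if d <= CUTOFF else "Lat2/Long2"


print(location_period("2021-06-30 23:59:59"))  # Lat1/Long1
print(location_period("2021-07-01 00:00:00"))  # Lat2/Long2
```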
I need to create new columns (corrected_lat and corrected_long) and assign Lat2/Long2 to every row of each Grid_ID, including the earlier rows.
I am currently doing this with groupBy and agg, followed by a join:
from pyspark.sql import functions as F

# Corrected lat/long per Grid_ID, taken from the rows dated 2021-07-01
df_dated = (
    df.withColumn("date", F.to_date("DateTimeStamp"))
    .filter(F.col("date") == "2021-07-01")
    .groupBy("Grid_ID")
    .agg(
        F.collect_set("Latitude").getItem(0).cast("float").alias("corrected_lat"),
        F.collect_set("Longitude").getItem(0).cast("float").alias("corrected_long"),
    )
    .withColumnRenamed("Grid_ID", "Grid_ID_dated")
)

# Attach the corrected values to every original row of the same Grid_ID
df_final = df.join(
    df_dated, on=df.Grid_ID == df_dated.Grid_ID_dated, how="inner"
).select(*df.columns, "corrected_lat", "corrected_long")
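To check the expected result on a handful of rows without a cluster, the filter, aggregate, and inner-join steps above can be modelled in plain Python. This is a sketch, not Spark code. The sample values are the placeholders from the question, while `Grid_2` and `groupby_join_model` are hypothetical additions. The sketch also makes one property of the inner join visible: a Grid_ID with no row on the filter date gets no corrected values and is dropped from the result.

```python
from datetime import datetime

# Sample rows from the question: (Grid_ID, Latitude, Longitude, DateTimeStamp)
rows = [
    ("Grid_1", "Lat1", "Long1", "2021-06-30 00:00:00"),
    ("Grid_1", "Lat1", "Long1", "2021-06-30 00:01:00"),
    ("Grid_1", "Lat2", "Long2", "2021-07-01 00:00:00"),
    ("Grid_1", "Lat2", "Long2", "2021-07-01 00:01:00"),
    # Hypothetical extra grid with no row on 2021-07-01
    ("Grid_2", "Lat3", "Long3", "2021-06-30 00:00:00"),
]


def groupby_join_model(rows, target_date="2021-07-01"):
    # Step 1: keep rows on the target date and take one (lat, long) per Grid_ID.
    # collect_set has no guaranteed order, so getItem(0) is only deterministic
    # when each grid has a single distinct value on that date, as assumed here.
    corrected = {}
    for grid, lat, lon, ts in rows:
        day = datetime.strptime(ts, "%Y-%m-%d %H:%M:%S").date().isoformat()
        if day == target_date:
            corrected.setdefault(grid, (lat, lon))
    # Step 2: inner join back onto every original row; unmatched grids are dropped
    return [row + corrected[row[0]] for row in rows if row[0] in corrected]


result = groupby_join_model(rows)
for r in result:
    print(r)
```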
Output:
Grid_ID Latitude Longitude DateTimeStamp corrected_lat corrected_long
Grid_1 Lat1 Long1 2021-06-30 00:00:00 Lat2 Long2
Grid_1 Lat1 Long1 2021-06-30 00:01:00 Lat2 Long2
Grid_1 Lat1 Long1 2021-06-30 00:02:00 Lat2 Long2
Grid_1 Lat2 Long2 2021-07-01 00:00:00 Lat2 Long2
Grid_1 Lat2 Long2 2021-07-01 00:01:00 Lat2 Long2
Grid_1 Lat2 Long2 2021-07-01 00:02:00 Lat2 Long2
Can a window function be used here instead, and would it be faster than the groupBy/agg approach?
Any other, faster approach would also be appreciated.
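A window function can do this without the separate aggregation and join. Partition by Grid_ID, order by DateTimeStamp, and take F.last over a frame that spans the whole partition. This applies the latest Latitude/Longitude of each Grid_ID to all of its rows: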
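from pyspark.sql import Window
from pyspark.sql import functions as F

# One window per Grid_ID, ordered by time, whose frame covers the whole partition
w = (
    Window.partitionBy("Grid_ID")
    .orderBy("DateTimeStamp")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

# F.last over the full frame returns the value from the latest row of each Grid_ID
df_final = df.select(
    *df.columns,
    F.last("Latitude").over(w).alias("corrected_lat"),
    F.last("Longitude").over(w).alias("corrected_long"),
)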
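The explicit `rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)` matters. When a window has an `orderBy` but no explicit frame, Spark's default frame runs from the start of the partition to the current row (plus any rows tied with it on DateTimeStamp), so `F.last` would simply hand each row back its own value. The plain-Python sketch below models both frames on the sample data. It is not Spark code, and `last_over_window` is a hypothetical helper that assumes distinct timestamps (ties are not modelled).

```python
from itertools import groupby
from operator import itemgetter

# Sample rows from the question: (Grid_ID, Latitude, Longitude, DateTimeStamp)
rows = [
    ("Grid_1", "Lat1", "Long1", "2021-06-30 00:00:00"),
    ("Grid_1", "Lat1", "Long1", "2021-06-30 00:01:00"),
    ("Grid_1", "Lat2", "Long2", "2021-07-01 00:00:00"),
    ("Grid_1", "Lat2", "Long2", "2021-07-01 00:01:00"),
]


def last_over_window(rows, full_frame=True):
    """Model F.last(lat/long) over a window partitioned by Grid_ID, ordered by DateTimeStamp."""
    out = []
    # partitionBy("Grid_ID") + orderBy("DateTimeStamp"); assumes distinct timestamps
    ordered = sorted(rows, key=itemgetter(0, 3))
    for _, part in groupby(ordered, key=itemgetter(0)):
        part = list(part)
        for i, row in enumerate(part):
            # Full frame: the whole partition; default frame: partition start .. current row
            frame = part if full_frame else part[: i + 1]
            out.append(row + (frame[-1][1], frame[-1][2]))
    return out


full = last_over_window(rows, full_frame=True)
running = last_over_window(rows, full_frame=False)
print([r[4:] for r in full])     # every row gets ('Lat2', 'Long2')
print([r[4:] for r in running])  # each row gets its own lat/long back
```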