
How to iterate over a pyspark dataframe to increment a value and reset it to 0


I have a pyspark dataframe with the following fields:

  • dt = timestamp (one row per hour)
  • rain_1h = rain in mm that hour

Now I need to calculate the number of dry hours: the counter must increment for every hour without rain, and reset to 0 whenever it rains.
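
For illustration, here is a plain-Python sketch of the counter I have in mind (the rain values below are made up):

rain_1h = [0, 0, 1, 0, 0, 0, 1]           # hypothetical hourly rain in mm
dry_hour_count = []
count = 0
for rain in rain_1h:
    count = 0 if rain > 0 else count + 1  # reset on rain, otherwise increment
    dry_hour_count.append(count)
print(dry_hour_count)                     # [1, 2, 0, 1, 2, 3, 0]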

I tried the following:

from pyspark.sql import Window
from pyspark.sql.functions import col, lag, sum, when

def calculate_dry_hours(df):
    window_spec = Window.orderBy("dt")

    # Flag dry and rainy hours
    df = df.withColumn('DryHour', when(col('rain_1h') == 0, 1).otherwise(0))
    df = df.withColumn('RainHour', when(col('rain_1h') > 0, 1).otherwise(0))

    # Create a column "lag_rain1h" for the lag of "rain_1h"
    df = df.withColumn("lag_rain1h", lag("rain_1h").over(window_spec))
    dry_hour_window = Window.partitionBy().orderBy('dt')

    # Running sum of dry hours; try to reset it when the previous hour was rainy
    df = df.withColumn('DryHourCount', when(col('RainHour') == 0, sum('DryHour').over(dry_hour_window)).otherwise(0))
    df = df.withColumn('DryHourCount', when((col('RainHour') == 0) & (lag('RainHour', 1).over(dry_hour_window) == 1), 0).otherwise(col('DryHourCount')))

    return df

However, this is not giving the desired results.


Solution

  • For every row, count how many of the hours up to and including it had more than 0 mm of rain. This running count acts as a 'group id': every rainy hour starts a new group, and within each group we can simply number the rows.

    Create some test data for one day:

    from datetime import datetime, timedelta
    from random import seed, randint

    seed(42)
    data = [(datetime.fromisoformat('2023-01-01T00:00:00') + timedelta(hours=h), randint(0, 1)) for h in range(0, 24)]
    df = spark.createDataFrame(data, ['dt', 'rain_1h'])
    df.orderBy('dt').show()
    

    Output:

    +-------------------+-------+
    |                 dt|rain_1h|
    +-------------------+-------+
    |2023-01-01 00:00:00|      0|
    |2023-01-01 01:00:00|      0|
    |2023-01-01 02:00:00|      1|
    |2023-01-01 03:00:00|      0|
    |2023-01-01 04:00:00|      0|
    |2023-01-01 05:00:00|      0|
    |2023-01-01 06:00:00|      0|
    |2023-01-01 07:00:00|      0|
    |2023-01-01 08:00:00|      1|
    |2023-01-01 09:00:00|      0|
    [...]
    

    Now calculate the 'group ids' and the row number within each group:

    from pyspark.sql import functions as F
    from pyspark.sql import Window
    
    group_window = Window.orderBy('dt')
    hour_window = Window.partitionBy('dry_group').orderBy('dt')
    
    df.withColumn('dry_group', F.sum((F.col('rain_1h')!=0).cast('int')).over(group_window)) \
      .withColumn('dry_hours_within_group', F.row_number().over(hour_window) - 1) \
      .show(n=30)
    

    Output:

    +-------------------+-------+---------+----------------------+
    |                 dt|rain_1h|dry_group|dry_hours_within_group|
    +-------------------+-------+---------+----------------------+
    |2023-01-01 00:00:00|      0|        0|                     0|
    |2023-01-01 01:00:00|      0|        0|                     1|
    |2023-01-01 02:00:00|      1|        1|                     0|
    |2023-01-01 03:00:00|      0|        1|                     1|
    |2023-01-01 04:00:00|      0|        1|                     2|
    |2023-01-01 05:00:00|      0|        1|                     3|
    |2023-01-01 06:00:00|      0|        1|                     4|
    |2023-01-01 07:00:00|      0|        1|                     5|
    |2023-01-01 08:00:00|      1|        2|                     0|
    |2023-01-01 09:00:00|      0|        2|                     1|
    |2023-01-01 10:00:00|      0|        2|                     2|
    |2023-01-01 11:00:00|      0|        2|                     3|
    |2023-01-01 12:00:00|      0|        2|                     4|
    |2023-01-01 13:00:00|      0|        2|                     5|
    |2023-01-01 14:00:00|      0|        2|                     6|
    |2023-01-01 15:00:00|      0|        2|                     7|
    |2023-01-01 16:00:00|      1|        3|                     0|
    |2023-01-01 17:00:00|      0|        3|                     1|
    |2023-01-01 18:00:00|      1|        4|                     0|
    |2023-01-01 19:00:00|      1|        5|                     0|
    |2023-01-01 20:00:00|      0|        5|                     1|
    |2023-01-01 21:00:00|      0|        5|                     2|
    |2023-01-01 22:00:00|      1|        6|                     0|
    |2023-01-01 23:00:00|      1|        7|                     0|
    +-------------------+-------+---------+----------------------+
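
    The dry_hours_within_group column is essentially the DryHourCount from the question (it only differs for the leading hours before the first rain, which have no preceding rainy hour). As a short follow-up, reusing the imports and the two windows defined above, you can keep just that column; the final aggregation is only a guess at a possible end goal, not something the question asked for:

    result = df \
        .withColumn('dry_group', F.sum((F.col('rain_1h') != 0).cast('int')).over(group_window)) \
        .withColumn('DryHourCount', F.row_number().over(hour_window) - 1) \
        .drop('dry_group')

    # Longest dry spell in the data (assumed follow-up, purely illustrative):
    result.agg(F.max('DryHourCount').alias('longest_dry_spell')).show()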
    

    Remark: Spark will print out a warning

    WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.

    for the first window, group_window, because it does not define a partition. Ordering a single window over the complete dataframe is not ideal: Spark moves all the data to a single partition, i.e. onto a single executor. There are several answers on SO discussing this limitation.
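
    One common way to keep the work parallel, sketched below as an assumption rather than part of the original answer, is to add a partitioning key to both windows, for example a hypothetical station_id column if the dataframe holds one hourly series per weather station:

    from pyspark.sql import functions as F
    from pyspark.sql import Window

    # Assumes a hypothetical 'station_id' column that splits the data into
    # independent hourly series; the dry-hour logic is then computed per station,
    # so Spark no longer needs to pull everything onto one partition.
    group_window = Window.partitionBy('station_id').orderBy('dt')
    hour_window = Window.partitionBy('station_id', 'dry_group').orderBy('dt')

    df.withColumn('dry_group', F.sum((F.col('rain_1h') != 0).cast('int')).over(group_window)) \
      .withColumn('dry_hours_within_group', F.row_number().over(hour_window) - 1) \
      .show(n=30)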