apache-spark, pyspark, azure-synapse

PySpark: count streaks of consecutive 1 values


For each station_no, I want to number the streaks of consecutive 1 values in group_flag, ordered by period_number. Rows with a 0 should stay 0, and each new streak of consecutive 1s should get the previous streak's number plus 1.

For example, the following table:

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Given table data: (station_no, period_number, group_flag)
table_data = [
    ("BAN", 0, 1),
    ("BAN", 1, 0),
    ("BAN", 2, 1),
    ("BAN", 3, 1),
    ("BAN", 4, 1),
    ("BAN", 5, 0),
    ("BAN", 6, 1),
    ("BOZ", 0, 0),
    ("BOZ", 1, 1),
    ("BOZ", 2, 1),
    ("BOZ", 3, 0),
    ("BOZ", 4, 1),
    ("BOZ", 5, 0),
    ("BOZ", 6, 1)
]

# Define the schema for the DataFrame
schema = StructType([
    StructField("station_no", StringType(), True),
    StructField("period_number", IntegerType(), True),
    StructField("group_flag", IntegerType(), True)
])

# Create a DataFrame from the table data and schema
# (`spark` is assumed to be the SparkSession that the notebook already provides)
df = spark.createDataFrame(table_data, schema=schema)
df.show()

The table containing the starting data:

+----------+-------------+----------+
|station_no|period_number|group_flag|
+----------+-------------+----------+
|       BAN|            0|         1|
|       BAN|            1|         0|
|       BAN|            2|         1|
|       BAN|            3|         1|
|       BAN|            4|         1|
|       BAN|            5|         0|
|       BAN|            6|         1|
|       BOZ|            0|         0|
|       BOZ|            1|         1|
|       BOZ|            2|         1|
|       BOZ|            3|         0|
|       BOZ|            4|         1|
|       BOZ|            5|         0|
|       BOZ|            6|         1|
+----------+-------------+----------+

The result should look like this:

+----------+-------------+----------+
|station_no|period_number|group_flag|
+----------+-------------+----------+
|       BAN|            0|         1|
|       BAN|            1|         0|
|       BAN|            2|         2|
|       BAN|            3|         2|
|       BAN|            4|         2|
|       BAN|            5|         0|
|       BAN|            6|         3|
|       BOZ|            0|         0|
|       BOZ|            1|         1|
|       BOZ|            2|         1|
|       BOZ|            3|         0|
|       BOZ|            4|         2|
|       BOZ|            5|         0|
|       BOZ|            6|         3|
+----------+-------------+----------+

I have tried some window and ranking functions, but I am not able to figure it out.


Solution

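  • I split the calculation into multiple columns to make it easier to follow, but you can write the result back into the original column, as in your example.

    First, use a lag function to find the rows where the value changes from 0 to 1. A missing previous row (the first row of each station) counts as 0.

    Then take a running sum of those changes over all previous rows up to and including the current row, and set the rows where group_flag is 0 back to 0.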

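    from pyspark.sql import Window
    from pyspark.sql import functions as F

    # Both windows run per station, ordered by period
    lag_w = Window.partitionBy('station_no').orderBy(F.asc('period_number'))
    sum_w = Window.partitionBy('station_no').orderBy(F.asc('period_number')).rowsBetween(Window.unboundedPreceding, Window.currentRow)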
    
    df.withColumn(
        'changed_from_previous',
        (F.col('group_flag')==1) & (F.coalesce(F.lag('group_flag').over(lag_w), F.lit(0))==0)
    ).withColumn(
        'total_changes',
        F.when(
            F.col('group_flag')==1,
            F.sum(F.col('changed_from_previous').cast('int')).over(sum_w)
        ).otherwise(F.lit(0))
    ).show()
    
    +----------+-------------+----------+---------------------+-------------+
    |station_no|period_number|group_flag|changed_from_previous|total_changes|
    +----------+-------------+----------+---------------------+-------------+
    |       BAN|            0|         1|                 true|            1|
    |       BAN|            1|         0|                false|            0|
    |       BAN|            2|         1|                 true|            2|
    |       BAN|            3|         1|                false|            2|
    |       BAN|            4|         1|                false|            2|
    |       BAN|            5|         0|                false|            0|
    |       BAN|            6|         1|                 true|            3|
    |       BOZ|            0|         0|                false|            0|
    |       BOZ|            1|         1|                 true|            1|
    |       BOZ|            2|         1|                false|            1|
    |       BOZ|            3|         0|                false|            0|
    |       BOZ|            4|         1|                 true|            2|
    |       BOZ|            5|         0|                false|            0|
    |       BOZ|            6|         1|                 true|            3|
    +----------+-------------+----------+---------------------+-------------+
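
To cross-check the logic without a Spark session, the same two steps can be traced in plain Python. This is only a sketch for verifying the expected numbers on small data, not a replacement for the window functions. The function name `number_streaks` is hypothetical, and it assumes group_flag is always 0 or 1 (no nulls). It writes the streak number straight back into the group_flag position, as in the question's expected result.

```python
from itertools import groupby
from operator import itemgetter


def number_streaks(rows):
    """Replace each 1 in group_flag with its 1-based streak number per station."""
    out = []
    # Sort by station, then period, mirroring partitionBy/orderBy in the windows.
    ordered = sorted(rows, key=itemgetter(0, 1))
    for station, group in groupby(ordered, key=itemgetter(0)):
        previous = 0  # plays the role of coalesce(lag(group_flag), 0)
        streaks = 0   # running sum of 0 -> 1 changes seen so far
        for _, period, flag in group:
            if flag == 1 and previous == 0:
                streaks += 1  # a new streak starts on this row
            # Rows with flag 0 stay 0; rows with flag 1 get the streak number.
            out.append((station, period, streaks if flag == 1 else 0))
            previous = flag
    return out


rows = [
    ("BAN", 0, 1), ("BAN", 1, 0), ("BAN", 2, 1), ("BAN", 3, 1),
    ("BAN", 4, 1), ("BAN", 5, 0), ("BAN", 6, 1),
    ("BOZ", 0, 0), ("BOZ", 1, 1), ("BOZ", 2, 1), ("BOZ", 3, 0),
    ("BOZ", 4, 1), ("BOZ", 5, 0), ("BOZ", 6, 1),
]
result = number_streaks(rows)
```

The third element of each tuple in `result` should match the total_changes column above, which makes this handy as a reference when unit-testing the Spark version on a small sample.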