I want to number the consecutive streaks of 1 values within each station_no. Rows with a 0 should stay 0, and each new streak of consecutive 1s should get a counter that is incremented by 1.
For example, the table is created as follows:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Given table data
table_data = [
    ("BAN", 0, 1),
    ("BAN", 1, 0),
    ("BAN", 2, 1),
    ("BAN", 3, 1),
    ("BAN", 4, 1),
    ("BAN", 5, 0),
    ("BAN", 6, 1),
    ("BOZ", 0, 0),
    ("BOZ", 1, 1),
    ("BOZ", 2, 1),
    ("BOZ", 3, 0),
    ("BOZ", 4, 1),
    ("BOZ", 5, 0),
    ("BOZ", 6, 1)
]

# Define the schema for the DataFrame
schema = StructType([
    StructField("station_no", StringType(), True),
    StructField("period_number", IntegerType(), True),
    StructField("group_flag", IntegerType(), True)
])

# Create a DataFrame from the table data and schema
df = spark.createDataFrame(table_data, schema=schema)
df.show()
The table containing the starting data:
+----------+-------------+----------+
|station_no|period_number|group_flag|
+----------+-------------+----------+
| BAN| 0| 1|
| BAN| 1| 0|
| BAN| 2| 1|
| BAN| 3| 1|
| BAN| 4| 1|
| BAN| 5| 0|
| BAN| 6| 1|
| BOZ| 0| 0|
| BOZ| 1| 1|
| BOZ| 2| 1|
| BOZ| 3| 0|
| BOZ| 4| 1|
| BOZ| 5| 0|
| BOZ| 6| 1|
+----------+-------------+----------+
The result should look like this:
+----------+-------------+----------+
|station_no|period_number|group_flag|
+----------+-------------+----------+
| BAN| 0| 1|
| BAN| 1| 0|
| BAN| 2| 2|
| BAN| 3| 2|
| BAN| 4| 2|
| BAN| 5| 0|
| BAN| 6| 3|
| BOZ| 0| 0|
| BOZ| 1| 1|
| BOZ| 2| 1|
| BOZ| 3| 0|
| BOZ| 4| 2|
| BOZ| 5| 0|
| BOZ| 6| 3|
+----------+-------------+----------+
I have tried some window and ranking functions, but I am not able to figure it out.
I separated the logic into multiple columns to make it easier to follow, but you can fold it all back into the original column as in your example (see the sketch after the output below).
First, calculate on which rows the flag changes from 0 to 1, using a lag function.
Then sum the number of such changes over all previous rows up to the current row.
from pyspark.sql import functions as F
from pyspark.sql import Window

# Window looking at the previous row within each station
lag_w = Window.partitionBy('station_no').orderBy(F.asc('period_number'))
# Running window from the first row of the station up to the current row
sum_w = Window.partitionBy('station_no').orderBy(F.asc('period_number')).rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn(
    # True when the flag is 1 and the previous row's flag was 0 (or there is no previous row)
    'changed_from_previous',
    (F.col('group_flag') == 1) & (F.coalesce(F.lag('group_flag').over(lag_w), F.lit(0)) == 0)
).withColumn(
    # Running count of streak starts; rows with flag 0 keep the value 0
    'total_changes',
    F.when(
        F.col('group_flag') == 1,
        F.sum(F.col('changed_from_previous').cast('int')).over(sum_w)
    ).otherwise(F.lit(0))
).show()
+----------+-------------+----------+---------------------+-------------+
|station_no|period_number|group_flag|changed_from_previous|total_changes|
+----------+-------------+----------+---------------------+-------------+
| BAN| 0| 1| true| 1|
| BAN| 1| 0| false| 0|
| BAN| 2| 1| true| 2|
| BAN| 3| 1| false| 2|
| BAN| 4| 1| false| 2|
| BAN| 5| 0| false| 0|
| BAN| 6| 1| true| 3|
| BOZ| 0| 0| false| 0|
| BOZ| 1| 1| true| 1|
| BOZ| 2| 1| false| 1|
| BOZ| 3| 0| false| 0|
| BOZ| 4| 1| true| 2|
| BOZ| 5| 0| false| 0|
| BOZ| 6| 1| true| 3|
+----------+-------------+----------+---------------------+-------------+
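To get everything back into the original column, a minimal sketch that reuses lag_w and sum_w from above, overwrites group_flag with the streak number, and drops the helper column (result is just an illustrative variable name):
result = df.withColumn(
    'changed_from_previous',
    (F.col('group_flag') == 1) & (F.coalesce(F.lag('group_flag').over(lag_w), F.lit(0)) == 0)
).withColumn(
    # Replace group_flag with the running streak counter; 0 rows stay 0
    'group_flag',
    F.when(
        F.col('group_flag') == 1,
        F.sum(F.col('changed_from_previous').cast('int')).over(sum_w)
    ).otherwise(F.lit(0))
).drop('changed_from_previous')

result.show()
This should reproduce the expected result table from the question.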