Suppose I have the following DataFrame in PySpark:
object | time | has_changed |
---|---|---|
A | 1 | 0 |
A | 2 | 1 |
A | 4 | 0 |
A | 7 | 1 |
B | 2 | 1 |
B | 5 | 0 |
What I want is to add a new column that, for each row, tracks the time elapsed since the last value change for the current object (or since the first element of the corresponding partition, if no change has occurred yet). For the table I've posted above, the result would be the following:
object | time | has_changed | time_alive |
---|---|---|---|
A | 1 | 0 | 0 |
A | 2 | 1 | 1 |
A | 4 | 0 | 2 |
A | 7 | 1 | 5 |
B | 2 | 1 | 0 |
B | 5 | 0 | 3 |
That is, within each partition by the "object" column, sorted by the "time" column, each row's value is the difference between that row's time and the most recent earlier time at which there is a 1 in the "has_changed" column (if no such 1 is found, the window falls back to the first element of the partition). For example, for the row (A, 7) the last change before it was at time 2, so time_alive = 7 - 2 = 5.
What I would like to implement would be something like the following (pseudo-code):
from pyspark.sql.window import Window as w
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
spark = SparkSession.builder.getOrCreate()
# Define the data
data = [("A", 1, 0), ("A", 2, 1), ("A", 4, 0), ("A", 7, 1), ("B", 2, 1), ("B", 5, 0)]
# Define the schema
schema = ["object", "time", "has_changed"]
# Create the DataFrame
df = spark.createDataFrame(data, schema)
# Window function (pseudo-code, this won't work)
window = (
w.partitionBy("object")
.orderBy("time")
.rowsBetween(f.when(f.col("has_changed") == 1), w.currentRow)
)
df.withColumn("time_alive", f.col("time") - f.lag("time", 1).over(window))
You can't express the frame boundary as a condition, since rowsBetween only accepts fixed offsets, but you can get the same result by masking the change times, then forward-filling and shifting them. Create a window specification:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
W = Window.partitionBy('object').orderBy('time')
Mask the values in the time column where has_changed is 0:
masked = F.when(F.col('has_changed') == 1, F.col('time'))
df = df.withColumn('masked', masked)
# +------+----+-----------+------+
# |object|time|has_changed|masked|
# +------+----+-----------+------+
# | A| 1| 0| NULL|
# | A| 2| 1| 2|
# | A| 4| 0| NULL|
# | A| 7| 1| 7|
# | B| 2| 1| 2|
# | B| 5| 0| NULL|
# +------+----+-----------+------+
Calculate the first value in the time column per group:
df = df.withColumn('first', F.first('time').over(W))
# +------+----+-----------+------+-----+
# |object|time|has_changed|masked|first|
# +------+----+-----------+------+-----+
# | A| 1| 0| NULL| 1|
# | A| 2| 1| 2| 1|
# | A| 4| 0| NULL| 1|
# | A| 7| 1| 7| 1|
# | B| 2| 1| 2| 2|
# | B| 5| 0| NULL| 2|
# +------+----+-----------+------+-----+
Forward fill the last valid value in the masked column over the window, then shift it one row with lag so that a row where has_changed is 1 measures against the previous change rather than against itself:
last_changed = F.lag(F.last('masked', ignorenulls=True).over(W)).over(W)
df = df.withColumn('last_changed', last_changed)
# +------+----+-----------+------+-----+------------+
# |object|time|has_changed|masked|first|last_changed|
# +------+----+-----------+------+-----+------------+
# | A| 1| 0| NULL| 1| NULL|
# | A| 2| 1| 2| 1| NULL|
# | A| 4| 0| NULL| 1| 2|
# | A| 7| 1| 7| 1| 2|
# | B| 2| 1| 2| 2| NULL|
# | B| 5| 0| NULL| 2| 2|
# +------+----+-----------+------+-----+------------+
Fill the nulls in last_changed with the first value in the group:
last_changed = F.when(last_changed.isNull(), F.col('first')).otherwise(last_changed)
df = df.withColumn('last_changed', last_changed)
# +------+----+-----------+------+-----+------------+
# |object|time|has_changed|masked|first|last_changed|
# +------+----+-----------+------+-----+------------+
# | A| 1| 0| NULL| 1| 1|
# | A| 2| 1| 2| 1| 1|
# | A| 4| 0| NULL| 1| 2|
# | A| 7| 1| 7| 1| 2|
# | B| 2| 1| 2| 2| 2|
# | B| 5| 0| NULL| 2| 2|
# +------+----+-----------+------+-----+------------+
Finally, subtract last_changed from the time column to calculate time_alive:
df = df.withColumn('time_alive', F.col('time') - last_changed)
# +------+----+-----------+------+-----+------------+----------+
# |object|time|has_changed|masked|first|last_changed|time_alive|
# +------+----+-----------+------+-----+------------+----------+
# | A| 1| 0| NULL| 1| 1| 0|
# | A| 2| 1| 2| 1| 1| 1|
# | A| 4| 0| NULL| 1| 2| 2|
# | A| 7| 1| 7| 1| 2| 5|
# | B| 2| 1| 2| 2| 2| 0|
# | B| 5| 0| NULL| 2| 2| 3|
# +------+----+-----------+------+-----+------------+----------+
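Putting it all together, here is a minimal end-to-end sketch of the same logic, with the helper columns collapsed into a single expression (coalesce stands in for the when/otherwise null fill):
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("A", 1, 0), ("A", 2, 1), ("A", 4, 0), ("A", 7, 1), ("B", 2, 1), ("B", 5, 0)],
    ["object", "time", "has_changed"],
)
W = Window.partitionBy('object').orderBy('time')
masked = F.when(F.col('has_changed') == 1, F.col('time'))
# Time of the last change strictly before the current row,
# falling back to the first time in the partition
last_changed = F.coalesce(
    F.lag(F.last(masked, ignorenulls=True).over(W)).over(W),
    F.first('time').over(W),
)
df.withColumn('time_alive', F.col('time') - last_changed).show()
Since time is strictly increasing within each partition here (an assumption on my part, not stated in the answer above), an alternative sketch is a running max of the masked column over all strictly preceding rows, which avoids nesting lag and last:
last_changed = F.coalesce(
    F.max(masked).over(W.rowsBetween(Window.unboundedPreceding, -1)),
    F.first('time').over(W),
)
df.withColumn('time_alive', F.col('time') - last_changed).show()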