Take the following data as an example:
+---+--------+----------+
| id|column_a|zero_count|
+---+--------+----------+
| 1| 0| 0|
| 2| 0| 0|
| 3| 0| 0|
| 4| 1| 3|
| 5| 0| 0|
| 6| 0| 0|
| 7| 0| 0|
| 8| 0| 0|
| 9| 1| 4|
| 10| 0| 0|
+---+--------+----------+
I wish to get from column_a to the zero_count column, i.e. each time column_a != 0, I want to know how many 0s preceded it.
You can do this by using window functions.
Let's make your dataframe:
from pyspark.sql.session import SparkSession
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [
        (1, 0),
        (2, 0),
        (3, 0),
        (4, 1),
        (5, 0),
        (6, 0),
        (7, 0),
        (8, 0),
        (9, 1),
        (10, 0),
    ],
    ["id", "column_a"],
)
A possible solution looks like this (quite verbose because I'm keeping intermediary results so you can see what happens):
from pyspark.sql.window import Window
import pyspark.sql.functions as F
# No partitionBy here: the whole dataframe is treated as one ordered partition
# (see the assumptions at the end).
window = Window.orderBy("id")

df2 = df.select(
    "*",
    # cumulative sum of the lagged column_a: this labels each group of zeros
    F.sum(F.lag("column_a").over(window)).over(window).alias("cumsum"),
    # the lag leaves a null on the first row; fall back to column_a (0) there
    F.coalesce("cumsum", F.col("column_a")).alias("clean"),
)
>>> df2.show()
+---+--------+------+-----+
| id|column_a|cumsum|clean|
+---+--------+------+-----+
| 1| 0| null| 0|
| 2| 0| 0| 0|
| 3| 0| 0| 0|
| 4| 1| 0| 0|
| 5| 0| 1| 1|
| 6| 0| 1| 1|
| 7| 0| 1| 1|
| 8| 0| 1| 1|
| 9| 1| 1| 1|
| 10| 0| 2| 2|
+---+--------+------+-----+
# Within each "clean" group, number the rows; subtract 1 so the row that
# contains the 1 itself is not counted as a zero.
windowspec = Window.orderBy("id").partitionBy("clean")
df3 = df2.withColumn("row_nr", F.row_number().over(windowspec) - 1)
>>> df3.show()
+---+--------+------+-----+------+
| id|column_a|cumsum|clean|row_nr|
+---+--------+------+-----+------+
| 1| 0| null| 0| 0|
| 2| 0| 0| 0| 1|
| 3| 0| 0| 0| 2|
| 4| 1| 0| 0| 3|
| 5| 0| 1| 1| 0|
| 6| 0| 1| 1| 1|
| 7| 0| 1| 1| 2|
| 8| 0| 1| 1| 3|
| 9| 1| 1| 1| 4|
| 10| 0| 2| 2| 0|
+---+--------+------+-----+------+
output = df3.select(
    "id",
    "column_a",
    F.when(F.col("column_a") != 0, F.col("row_nr"))
    .otherwise(F.lit(0))
    .alias("zero_count"),
)
>>> output.show()
+---+--------+----------+
| id|column_a|zero_count|
+---+--------+----------+
| 1| 0| 0|
| 2| 0| 0|
| 3| 0| 0|
| 4| 1| 3|
| 5| 0| 0|
| 6| 0| 0|
| 7| 0| 0|
| 8| 0| 0|
| 9| 1| 4|
| 10| 0| 0|
+---+--------+----------+
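For completeness, the same steps can be chained in one go, without keeping the intermediary columns in the result (this is just a condensed rewrite of the solution above, using the same window expressions):
from pyspark.sql.window import Window
import pyspark.sql.functions as F

window = Window.orderBy("id")
output = (
    df
    # label each group of zeros with the cumulative sum of the lagged column_a
    .withColumn("cumsum", F.sum(F.lag("column_a").over(window)).over(window))
    # replace the initial null left by the lag with column_a's value (0)
    .withColumn("clean", F.coalesce("cumsum", F.col("column_a")))
    # position within the group; -1 because the non-zero row itself is not a zero
    .withColumn("row_nr", F.row_number().over(Window.partitionBy("clean").orderBy("id")) - 1)
    .select(
        "id",
        "column_a",
        F.when(F.col("column_a") != 0, F.col("row_nr")).otherwise(F.lit(0)).alias("zero_count"),
    )
)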
The general idea is:

1. Make groups that we can partitionBy later on. We do that by calculating a cumulative sum (the cumsum column). We use the lag function in there because the 1 occurrences are part of the previous group of 0 values (see the small illustration right after this list). Then we clean up that cumsum column, which gives us df2.
2. Use the row_number() function as a kind of "proxy" for the number of zeroes. We just need to subtract 1 because the row in which we have the 1 does not count as a 0. That gives us df3.
3. Building output is then simple: wherever column_a != 0, take the row number value, otherwise set it to 0.
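To see what the lag is doing, you can look at the shifted column on its own (reusing df, window and F from above; the lagged_a alias is just for display). This is the column whose cumulative sum becomes cumsum:
# lag("column_a") shifts column_a down one row within the window (null on the
# first row), so each 1 ends up next to the zeros that precede it before we sum.
df.select(
    "id",
    "column_a",
    F.lag("column_a").over(window).alias("lagged_a"),
).show()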
Assumptions:

- You can order your dataframe by the id column.
- Your data is small enough to be processed in one go (since there is no partitionBy on that first window object). This will not work with really big data; if you have really big data, you probably should have some other column on which you can partitionBy (see the sketch just below).
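As a sketch of that last point, assuming your real data has some grouping column, say group_col (hypothetical, not present in the example data), you would partition both windows by it so Spark does not have to sort everything in a single partition:
# group_col is assumed to exist in your real data; the zero counting then
# runs independently within each group instead of over the whole dataframe.
window = Window.partitionBy("group_col").orderBy("id")
windowspec = Window.partitionBy("group_col", "clean").orderBy("id")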