Tags: apache-spark, pyspark

Given a column, x, I wish to count the number of consecutive 0s preceding each non-zero value, resetting the count every time x is not equal to 0.


Take the following data as an example

+---+--------+----------+
| id|column_a|zero_count|
+---+--------+----------+
|  1|       0|         0|
|  2|       0|         0|
|  3|       0|         0|
|  4|       1|         3|
|  5|       0|         0|
|  6|       0|         0|
|  7|       0|         0|
|  8|       0|         0|
|  9|       1|         4|
| 10|       0|         0|
+---+--------+----------+

I wish to get from column_a to the zero_count column, i.e. each time column_a != 0, I want to know how many 0s preceded it.


Solution

  • You can do this by using window functions.

    Let's make your dataframe:

    from pyspark.sql.session import SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [
            (1, 0),
            (2, 0),
            (3, 0),
            (4, 1),
            (5, 0),
            (6, 0),
            (7, 0),
            (8, 0),
            (9, 1),
            (10, 0),
        ],
        ["id", "column_a"],
    )
    

    A possible solution looks like this (quite verbose, because I'm keeping the intermediate results so you can see what happens):

    from pyspark.sql.window import Window
    import pyspark.sql.functions as F
    
    window = Window.orderBy("id")
    df2 = df.select(
        "*",
        # running sum of the lagged column: every non-zero value opens a new group,
        # but only from the row *after* it (thanks to lag)
        F.sum(F.lag("column_a").over(window)).over(window).alias("cumsum"),
        # the first row's cumsum is null (lag has nothing to look back at), so clean it up;
        # note: referencing the "cumsum" alias inside the same select relies on lateral
        # column aliases (newer Spark versions, 3.4+); on older versions, compute "clean"
        # in a second select
        F.coalesce("cumsum", F.col("column_a")).alias("clean"),
    )
    
    >>> df2.show()
    +---+--------+------+-----+
    | id|column_a|cumsum|clean|
    +---+--------+------+-----+
    |  1|       0|  null|    0|
    |  2|       0|     0|    0|
    |  3|       0|     0|    0|
    |  4|       1|     0|    0|
    |  5|       0|     1|    1|
    |  6|       0|     1|    1|
    |  7|       0|     1|    1|
    |  8|       0|     1|    1|
    |  9|       1|     1|    1|
    | 10|       0|     2|    2|
    +---+--------+------+-----+
    
    
    # partition by the groups we just built; row_number gives each row its position within its group
    windowspec = Window.orderBy("id").partitionBy("clean")
    df3 = df2.withColumn("row_nr", F.row_number().over(windowspec) - 1)
    
    >>> df3.show()
    +---+--------+------+-----+------+
    | id|column_a|cumsum|clean|row_nr|
    +---+--------+------+-----+------+
    |  1|       0|  null|    0|     0|
    |  2|       0|     0|    0|     1|
    |  3|       0|     0|    0|     2|
    |  4|       1|     0|    0|     3|
    |  5|       0|     1|    1|     0|
    |  6|       0|     1|    1|     1|
    |  7|       0|     1|    1|     2|
    |  8|       0|     1|    1|     3|
    |  9|       1|     1|    1|     4|
    | 10|       0|     2|    2|     0|
    +---+--------+------+-----+------+
    
    
    output = df3.select(
        "id",
        "column_a",
        # keep the count only on the rows where column_a is non-zero, else 0
        F.when(F.col("column_a") != 0, F.col("row_nr"))
        .otherwise(F.lit(0))
        .alias("zero_count"),
    )
    
    >>> output.show()
    +---+--------+----------+
    | id|column_a|zero_count|
    +---+--------+----------+
    |  1|       0|         0|
    |  2|       0|         0|
    |  3|       0|         0|
    |  4|       1|         3|
    |  5|       0|         0|
    |  6|       0|         0|
    |  7|       0|         0|
    |  8|       0|         0|
    |  9|       1|         4|
    | 10|       0|         0|
    +---+--------+----------+
    
    

    The general idea is (a condensed version of the same pipeline is sketched after this list):

    • First we create groups that we can partitionBy later on, by calculating a cumulative sum (the cumsum column). We use the lag function there because the non-zero occurrences belong to the previous group of 0 values. Then we clean up that cumsum column; that gives df2.
    • Now we have the groups of data! We can use the row_number() function as a kind of "proxy" for the number of zeroes. We just need to subtract 1 because the row containing the non-zero value does not itself count as a 0. That gives df3.
    • Creating output is simple: where column_a != 0 we take the row number value, otherwise we set zero_count to 0.

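    If it helps, here is a condensed sketch of the same pipeline chained with withColumn, so the intermediate columns are computed step by step and dropped at the end (the names lagged, grp and row_nr are just local names for this sketch, not part of the original answer):

    from pyspark.sql.window import Window
    import pyspark.sql.functions as F
    
    w = Window.orderBy("id")
    output = (
        df
        # shift column_a down one row so each non-zero value is grouped with the 0s before it
        .withColumn("lagged", F.lag("column_a").over(w))
        # running sum of the lagged values -> one group id per run of 0s (the "clean" column above)
        .withColumn("grp", F.coalesce(F.sum("lagged").over(w), F.col("column_a")))
        # position within each group, minus 1 because the non-zero row itself is not a 0
        .withColumn("row_nr", F.row_number().over(Window.partitionBy("grp").orderBy("id")) - 1)
        # keep the count only where column_a != 0
        .select(
            "id",
            "column_a",
            F.when(F.col("column_a") != 0, F.col("row_nr")).otherwise(F.lit(0)).alias("zero_count"),
        )
    )
    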
    Assumptions:

    • You want to orderBy the id column
    • Because the first window object only uses orderBy (no partitionBy), Spark pulls all the data into a single partition on one executor. That will not work with really big data; in that case you should have some other column on which you can partitionBy, as sketched below.
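
    For example, if your data contained such a column (here called group_col, which is purely hypothetical and not part of the example data), the window definitions could look like this, letting Spark distribute the work across partitions:

    from pyspark.sql.window import Window
    
    # "group_col" is a hypothetical column that splits the data into independent chunks;
    # with partitionBy, the ordering and counting happen per chunk instead of on one executor
    window = Window.partitionBy("group_col").orderBy("id")
    windowspec = Window.partitionBy("group_col", "clean").orderBy("id")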