Tags: apache-spark, pyspark, group-by, window-functions

Adding a value to a column when a certain condition is met, per group, in Spark


I have something that is fairly simple, I guess.

What I'm trying to achieve is, per group, an increasing number (a rank?) whenever a certain condition is met. Each group starts at 1; whenever the condition is met, the following rows get the previous row's value + 1. This continues within the group: each time the condition is met, 1 is added.

The table below might show it more clearly. (The column I'm trying to create is 'what_i_want'.)

group   to_add_number   what_i_want
aaaaaa  0                 1
aaaaaa  0                 1
aaaaaa  1                 2
aaaaaa  0                 2
aaaaaa  0                 2
aaaaaa  1                 3
aaaaaa  0                 3
aaaaaa  0                 3
bbbbbb  0                 1
bbbbbb  1                 2
bbbbbb  1                 3
bbbbbb  0                 3
cccccc  0                 1
cccccc  0                 1
cccccc  0                 1
cccccc  1                 2

I think a window function (lag) might do it, but I can't get there.

What I tried is:

from pyspark.sql.functions import col, lit, when, lag, row_number
from pyspark.sql.window import Window

windowSpec = Window.partitionBy('group')

df = df.withColumn('tmp_rnk', lit(1))
df = df.withColumn('what_i_want',
                   when(col('to_add_number') == 0, lag('tmp_rnk').over(windowSpec))
                   .otherwise(col('what_i_want') + 1))

or

df = df.withColumn('tmp_rnk', lit(1))
df = df.withColumn('row_number_rank', row_number().over(windowSpec))
df = df.withColumn('what_i_want',
                   when((col('to_add_number') == 0) & (col('row_number_rank') == 1), lit(1))
                   .when((col('to_add_number') == 0) & (col('row_number_rank') > 1),
                         lag('what_i_want').over(windowSpec))
                   .otherwise(col('what_i_want') + 1))

I tried several variations and searched Stack Overflow for terms like 'conditional window functions' and 'lag, lead, ...', but nothing worked and I didn't find a duplicate question.


Solution

  • To get column what_i_want, you can run an incremental (running) sum on to_add_number over a window partitioned by group and ordered by a row-order column (order_id, generated with monotonically_increasing_id), then add 1.

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    # Running sum of to_add_number within each group, in original row order
    w = Window.partitionBy("group").orderBy("order_id")

    df.withColumn("order_id", F.monotonically_increasing_id())\
      .withColumn("what_i_want", F.sum("to_add_number").over(w) + 1)\
      .orderBy("order_id").drop("order_id").show()
    
    
    #+------+-------------+-----------+
    #| group|to_add_number|what_i_want|
    #+------+-------------+-----------+
    #|aaaaaa|            0|          1|
    #|aaaaaa|            0|          1|
    #|aaaaaa|            1|          2|
    #|aaaaaa|            0|          2|
    #|aaaaaa|            0|          2|
    #|aaaaaa|            1|          3|
    #|aaaaaa|            0|          3|
    #|aaaaaa|            0|          3|
    #|bbbbbb|            0|          1|
    #|bbbbbb|            1|          2|
    #|bbbbbb|            1|          3|
    #|bbbbbb|            0|          3|
    #|cccccc|            0|          1|
    #|cccccc|            0|          1|
    #|cccccc|            0|          1|
    #|cccccc|            1|          2|
    #+------+-------------+-----------+
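
  • A note on the window frame: because the window has an orderBy, Spark defaults to a frame of RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so the running sum already includes the current row. A minimal sketch making that frame explicit (same result, just spelled out):

    w = Window.partitionBy("group").orderBy("order_id")\
              .rowsBetween(Window.unboundedPreceding, Window.currentRow)

    df.withColumn("order_id", F.monotonically_increasing_id())\
      .withColumn("what_i_want", F.sum("to_add_number").over(w) + 1)\
      .orderBy("order_id").drop("order_id").show()

    Also note that monotonically_increasing_id only guarantees ids that increase with the existing row order; if the data already has a natural ordering column (a timestamp, for example), using that column in orderBy is safer.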