
PySpark drop leading zero values by group in dataframe


I have this dataframe -

data = [(0,1,5,5,0,4),
        (1,1,5,6,0,7),
        (2,1,5,7,1,1), 
        (3,1,5,8,1,8), 
        (4,1,5,9,1,1), 
        (5,1,5,10,1,0),
        (6,2,3,4,0,2),
        (7,2,3,5,0,6),
        (8,2,3,6,3,8),
        (9,2,3,7,0,2),
        (10,2,3,8,0,6),
        (11,2,3,9,6,1)
      ]
data_cols = ["id","item","store","week","sales","inventory"]
data_df = spark.createDataFrame(data=data, schema=data_cols)
display(data_df)

What I want is to group by item and store, order by week, and then delete all rows with leading 0s in sales within each group, like so:

data_new = [(2,1,5,7,1,1), 
        (3,1,5,8,1,8), 
        (4,1,5,9,1,1), 
        (5,1,5,10,1,0),
        (8,2,3,6,3,8),
        (9,2,3,7,0,2),
        (10,2,3,8,0,6),
        (11,2,3,9,6,1)
      ]
dep_cols = ["id","item","store","week","sales","inventory"]
data_df_new = spark.createDataFrame(data=data_new, schema = dep_cols)
display(data_df_new)

I need to do this in PySpark and I am new to it. Please help!


Solution

  • Use a window function to order by week within each (item, store) group and incrementally sum or collect_list the sales. Then either:

    1. Filter where the running sum is greater than 0,

    or

    2. Filter where the running list contains anything above 0 (a sketch of this variant follows the output below). I preferred the sum because it is faster.

    from pyspark.sql import Window, functions as F

    # Running total of sales per (item, store) group, ordered by week
    w = Window.partitionBy('item', 'store').orderBy(F.asc('week')).rowsBetween(Window.unboundedPreceding, Window.currentRow)
    # Keep a row only once the cumulative sales are positive, i.e. drop the leading zeros per group
    data_df.withColumn("sums", F.sum('sales').over(w)).filter(F.col('sums') > 0).drop('sums').show()
    
    +---+----+-----+----+-----+---------+
    | id|item|store|week|sales|inventory|
    +---+----+-----+----+-----+---------+
    |  2|   1|    5|   7|    1|        1|
    |  3|   1|    5|   8|    1|        8|
    |  4|   1|    5|   9|    1|        1|
    |  5|   1|    5|  10|    1|        0|
    |  8|   2|    3|   6|    3|        8|
    |  9|   2|    3|   7|    0|        2|
    | 10|   2|    3|   8|    0|        6|
    | 11|   2|    3|   9|    6|        1|
    +---+----+-----+----+-----+---------+
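
    For completeness, here is a minimal sketch of option 2 (the collect_list variant), assuming sales values are non-negative: collect the running list of sales per group and keep a row only once a non-zero value has appeared in that list.

    from pyspark.sql import Window, functions as F

    # Same window as above: per (item, store) group, ordered by week, from the start of the group to the current row
    w = Window.partitionBy('item', 'store').orderBy(F.asc('week')).rowsBetween(Window.unboundedPreceding, Window.currentRow)

    # Running list of sales seen so far; a row survives only if that list
    # already contains a non-zero value (array_max > 0 assumes non-negative sales)
    (data_df
        .withColumn('sales_so_far', F.collect_list('sales').over(w))
        .filter(F.array_max('sales_so_far') > 0)
        .drop('sales_so_far')
        .show())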