dataframe, apache-spark, pyspark

Get row number only for filtered rows in PySpark


I have the following dataframe:

id_cnt  id_prd  type  price
1       A       SS    10
2       A       AA    20
3       A       AA    25
1       B       AA    55
2       B       SS    50
3       B       AA    75
4       B       AA    80

I need to add a new column, rownumber: for each id_prd, order the rows by price in descending order and assign a row number, but only to rows where type = "AA". All other rows should get null.

Expected output:

id_cnt  id_prd  type  price  rownumber
1       A       SS    10     null
2       A       AA    20     2
3       A       AA    25     1
1       B       AA    55     3
2       B       SS    50     null
3       B       AA    75     2
4       B       AA    80     1

Solution

  • You can use a Window partitioned by both id_prd and type and ordered by price descending, then compute row_number only for rows where type = "AA":

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import row_number, desc, col, lit, when
    from pyspark.sql.window import Window

    spark = SparkSession.builder.getOrCreate()

    data = [
        (1, "A", "SS", 10),
        (2, "A", "AA", 20),
        (3, "A", "AA", 25),
        (1, "B", "AA", 55),
        (2, "B", "SS", 50),
        (3, "B", "AA", 75),
        (4, "B", "AA", 80),
    ]

    df = spark.createDataFrame(data, ["id_cnt", "id_prd", "type", "price"])

    # Partition by id_prd AND type so the SS rows land in separate
    # partitions and never consume row numbers meant for AA rows.
    window = Window.partitionBy("id_prd", "type").orderBy(desc("price"))

    # Assign a row number only to AA rows; everything else gets null.
    df.withColumn(
        "row_number",
        when(col("type") == lit("AA"), row_number().over(window)).otherwise(None),
    ).show()
    
    +------+------+----+-----+----------+
    |id_cnt|id_prd|type|price|row_number|
    +------+------+----+-----+----------+
    |     3|     A|  AA|   25|         1|
    |     2|     A|  AA|   20|         2|
    |     1|     A|  SS|   10|      null|
    |     4|     B|  AA|   80|         1|
    |     3|     B|  AA|   75|         2|
    |     1|     B|  AA|   55|         3|
    |     2|     B|  SS|   50|      null|
    +------+------+----+-----+----------+
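
  • Alternatively, you can filter to the AA rows first, number them over a window partitioned only by id_prd, and left-join the numbers back onto the original DataFrame. A minimal sketch, assuming (id_cnt, id_prd) uniquely identifies a row, as it does in the sample data:

    aa_window = Window.partitionBy("id_prd").orderBy(desc("price"))

    # Number only the AA rows, then attach the numbers back to every row;
    # rows without a match (type != "AA") get a null row_number from the
    # left join.
    aa_ranked = (
        df.filter(col("type") == "AA")
          .withColumn("row_number", row_number().over(aa_window))
    )

    df.join(
        aa_ranked.select("id_cnt", "id_prd", "row_number"),
        on=["id_cnt", "id_prd"],
        how="left",
    ).show()

    Both versions produce the same numbering here; the when-over-window version avoids the extra join, while the filter-and-join version keeps the window specification independent of the type column.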