
Rank on a subset of a partition - PySpark


The code snippet below creates the column 'rank' conditionally. I want to compute the rank on a subset of the partition, so I use a when clause that checks category == 'Y' and then apply the rank. However, the result below is not what I expected: where I expected rank=1, I get rank=2.

How can I compute a rank on a subset of a partition?

import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql import Row

data = [
    Row(id=1, code=14, category='N'),
    Row(id=1, code=20, category='Y'),
    Row(id=1, code=19, category='Y'),
    Row(id=1, code=22, category='Y'),
    Row(id=1, code=15, category='Y'),
]

ps_df = spark.createDataFrame(data)

window = Window.partitionBy('id').orderBy('code')

ps_df = ps_df.withColumn('rank', F.when(F.col('category')=='Y', F.rank().over(window)))

ps_df.show()
+---+----+--------+----+
| id|code|category|rank|
+---+----+--------+----+
|  1|  14|       N|NULL|
|  1|  15|       Y|   2|
|  1|  19|       Y|   3|
|  1|  20|       Y|   4|
|  1|  22|       Y|   5|
+---+----+--------+----+

Solution

  • I think we can use an alternative way, which doesn't need rank(), to achieve the same goal:

    import pyspark.sql.functions as func
    from pyspark.sql.window import Window

    ps_df = ps_df.withColumn(
        # flag the rows that should take part in the ranking
        "flag", func.when(func.col("category")=="Y", func.lit(1)).otherwise(func.lit(0))
    ).withColumn(
        # cumulative sum of the flag over the ordered partition acts as the rank
        "cumsum", func.sum("flag").over(Window.partitionBy("id").orderBy("code"))
    ).withColumn(
        # keep the value only for the flagged rows, NULL elsewhere
        "rank", func.when(func.col("category")=="Y", func.col("cumsum")).otherwise(func.lit(None))
    ).select(
        "id", "code", "category", "rank"
    )
    

    First, set a flag equal to 1 on the rows of the partition or group that you want to rank. Then use sum() as a window function to take the cumulative sum of that flag over the partition, which performs the ranking.
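
    As a rough sanity check, applying the snippet above to the sample data from the question should give ranks 1 through 4 for the category 'Y' rows and NULL for the 'N' row (the row order in show() may vary):

    ps_df.show()
    # Expected output (a sketch; row order may differ):
    # +---+----+--------+----+
    # | id|code|category|rank|
    # +---+----+--------+----+
    # |  1|  14|       N|NULL|
    # |  1|  15|       Y|   1|
    # |  1|  19|       Y|   2|
    # |  1|  20|       Y|   3|
    # |  1|  22|       Y|   4|
    # +---+----+--------+----+

    Note that the cumulative sum reproduces rank() on the 'Y' subset as long as the code values are unique within a partition; if two 'Y' rows tie on code, the cumulative sum and rank() can disagree on how the tie is numbered.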