
Groupby and return the row label of the maximum value in PySpark Dataframe


I have a dataframe:

data = [('I ran home', 3, 1, 10), 
       ('I went home', 3, 1, 11),
       ('I looked at the cat', 4, 2, 19),
       ('The cat looked at me', 5, 3, 20),
       ('I ran home', 3, 4, 10),
       ('I went homes', 3, 4, 12)]

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("text", StringType(), True),
    StructField("word_count", IntegerType(), True),
    StructField("group", IntegerType(), True),
    StructField("len_text", IntegerType(), True)])

df = spark.createDataFrame(data=data, schema=schema)
df.show(truncate=False)
+--------------------+----------+-----+--------+
|text                |word_count|group|len_text|
+--------------------+----------+-----+--------+
|I ran home          |3         |1    |10      |
|I went home         |3         |1    |11      |
|I looked at the cat |4         |2    |19      |
|The cat looked at me|5         |3    |20      |
|I ran home          |3         |4    |10      |
|I went homes        |3         |4    |12      |
+--------------------+----------+-----+--------+

I want to filter the rows so that, within each group, only the row with the greatest len_text value is kept. In pandas I can do this with idxmax():

df1 = df.loc[df.groupby('group')['len_text'].idxmax()]

Is there any analogue for pyspark? I want this result:

+--------------------+----------+-----+--------+
|text                |word_count|group|len_text|
+--------------------+----------+-----+--------+
|I went home         |3         |1    |11      |
|I looked at the cat |4         |2    |19      |
|The cat looked at me|5         |3    |20      |
|I went homes        |3         |4    |12      |
+--------------------+----------+-----+--------+

Solution

  • You can use a window function such as row_number:

    from pyspark.sql import functions as F, Window as W
    
    w = W.partitionBy('group').orderBy(F.desc('len_text'))  # largest len_text first within each group
    df = df.withColumn('_rn', F.row_number().over(w))       # number the rows within each group
    df = df.filter('_rn=1').drop('_rn')                     # keep only the top row per group
    
    df.show()
    # +--------------------+----------+-----+--------+
    # |                text|word_count|group|len_text|
    # +--------------------+----------+-----+--------+
    # |         I went home|         3|    1|      11|
    # | I looked at the cat|         4|    2|      19|
    # |The cat looked at me|         5|    3|      20|
    # |        I went homes|         3|    4|      12|
    # +--------------------+----------+-----+--------+