Tags: python, apache-spark, pyspark, apache-spark-sql

GroupBy column and filter rows with maximum value in Pyspark


I am almost certain this has been asked before, but a search through Stack Overflow did not answer my question. This is not a duplicate of [2], since I want the maximum value, not the most frequent item. I am new to pyspark and trying to do something really simple: I want to groupBy column "A" and then keep only the row of each group that has the maximum value in column "B". Like this:

import pyspark.sql.functions as F

df_cleaned = df.groupBy("A").agg(F.max("B"))

Unfortunately, this throws away all other columns: df_cleaned only contains column "A" and the maximum value of "B". How do I instead keep the full rows ("A", "B", "C", ...)?
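For illustration (using the same sample data defined in the answer below), that aggregation returns only column "A" and an automatically named max(B) column, something like:

df.groupBy("A").agg(F.max("B")).show()
#+---+------+
#|  A|max(B)|
#+---+------+
#|  a|     8|
#|  b|     3|
#+---+------+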


Solution

  • You can do this without a udf using a Window.

    Consider the following example:

    from pyspark.sql import SparkSession
    import pyspark.sql.functions as f

    spark = SparkSession.builder.getOrCreate()

    data = [
        ('a', 5),
        ('a', 8),
        ('a', 7),
        ('b', 1),
        ('b', 3)
    ]
    df = spark.createDataFrame(data, ["A", "B"])
    df.show()
    #+---+---+
    #|  A|  B|
    #+---+---+
    #|  a|  5|
    #|  a|  8|
    #|  a|  7|
    #|  b|  1|
    #|  b|  3|
    #+---+---+
    

    Create a Window partitioned by column A and use it to compute the maximum of each group. Then keep only the rows where the value in column B equals that group maximum.

    from pyspark.sql import Window
    w = Window.partitionBy('A')
    df.withColumn('maxB', f.max('B').over(w))\
        .where(f.col('B') == f.col('maxB'))\
        .drop('maxB')\
        .show()
    #+---+---+
    #|  A|  B|
    #+---+---+
    #|  a|  8|
    #|  b|  3|
    #+---+---+
    

    Or, equivalently, using Spark SQL:

    df.createOrReplaceTempView('table')
    q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB"
    spark.sql(q).show()
    #+---+---+
    #|  A|  B|
    #+---+---+
    #|  b|  3|
    #|  a|  8|
    #+---+---+
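
    Note that if the maximum of B is tied within a group, the filter above keeps every tied row. If you need exactly one row per group instead, a row_number over an ordered window is one option (a sketch, not part of the original answer):

    # order each partition by B descending, then keep only the first row per group
    w_ordered = Window.partitionBy('A').orderBy(f.col('B').desc())
    df.withColumn('rn', f.row_number().over(w_ordered))\
        .where(f.col('rn') == 1)\
        .drop('rn')\
        .show()

    Another way to keep all the columns without a window function is to aggregate and join the result back, e.g. df.join(df.groupBy('A').agg(f.max('B').alias('B')), on=['A', 'B']); like the window filter, this keeps every tied row.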