Tags: dataframe, apache-spark, pyspark, apache-spark-sql, top-n

Select rows from Spark DataFrame based on a condition


I have two Spark dataframes:

df1
+---+----+
| id| var|  
+---+----+
|323| [a]|
+---+----+

df2
+----+----------+----------+
| src| str_value| num_value|
+----+----------+----------+
| [a]|     ghn12|       0.0|
| [a]|     54fdg|       1.2|
| [a]|     90okl|       0.7|
| [b]|     jh456|       0.5|
| [a]|     ghn12|       0.2|
| [c]|     ghn12|       0.7|
+----+----------+----------+

I need to return the top 3 rows from df2 where df1.var == df2.src, taking the rows with the smallest num_value. So the desired output, sorted by num_value, is:

+----+----------+----------+
| src| str_value| num_value|
+----+----------+----------+
| [a]|     ghn12|       0.0|
| [a]|     ghn12|       0.2|
| [a]|     90okl|       0.7|
+----+----------+----------+

I know how to implement this using SQL, but I have some difficulties with PySpark/Spark SQL.
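For reproducibility, the two DataFrames above can be built like this (a minimal sketch; the bracketed values are treated as plain strings here, so if var/src are actually arrays the literals would need adjusting):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # hypothetical reconstruction of the example data shown above
    df1 = spark.createDataFrame([(323, '[a]')], ['id', 'var'])
    df2 = spark.createDataFrame(
        [('[a]', 'ghn12', 0.0),
         ('[a]', '54fdg', 1.2),
         ('[a]', '90okl', 0.7),
         ('[b]', 'jh456', 0.5),
         ('[a]', 'ghn12', 0.2),
         ('[c]', 'ghn12', 0.7)],
        ['src', 'str_value', 'num_value'],
    )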


Solution

  • I would do it using the dense_rank window function.

    from pyspark.sql import functions as F, Window as W

    # rank df2 rows within each src group by ascending num_value
    w = W.partitionBy('src').orderBy('num_value')
    df3 = (
        df2
        # keep only df2 rows whose src appears in df1.var (left semi join)
        .join(df1, df2.src == df1.var, 'semi')
        .withColumn('_rank', F.dense_rank().over(w))
        # keep the 3 smallest num_value ranks per src, then drop the helper column
        .filter('_rank <= 3')
        .drop('_rank')
    )
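  • Note that dense_rank keeps every row that ties on num_value, so more than 3 rows can come back for a src when values repeat. If exactly 3 rows are required, row_number can be swapped in; a minimal sketch, assuming the same df1/df2 as above:

    from pyspark.sql import functions as F, Window as W

    w = W.partitionBy('src').orderBy('num_value')
    df3 = (
        df2
        .join(df1, df2.src == df1.var, 'semi')
        # row_number assigns a unique position per row, so at most 3 rows per src survive
        .withColumn('_rn', F.row_number().over(w))
        .filter('_rn <= 3')
        .drop('_rn')
        .orderBy('num_value')   # match the desired output ordering
    )
    df3.show()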