Tags: dataframe, apache-spark, pyspark, apache-spark-sql, window-functions

PySpark Window function on entire data frame


Consider a PySpark data frame. I would like to summarize the entire data frame, per column, and append the result for every row.

+-----+----------+-----------+
|index|      col1| col2      |
+-----+----------+-----------+
|  0.0|0.58734024|0.085703015|
|  1.0|0.67304325| 0.17850411|

Expected result

+-----+----------+-----------+---------+---------+---------+---------+
|index|      col1|       col2| col1_min|col1_mean| col2_min|col2_mean|
+-----+----------+-----------+---------+---------+---------+---------+
|  0.0|0.58734024|0.085703015|       -5|      2.3|       -2|      1.4|
|  1.0|0.67304325| 0.17850411|       -5|      2.3|       -2|      1.4|

To my knowledge, I'll need a Window function with the whole data frame as the window, to keep the result attached to each row (instead of, for example, computing the stats separately and then joining them back to replicate for each row).

My questions are:

  1. How do I write a Window without any partitionBy or orderBy?

    I know there is the standard Window with partitionBy and orderBy, but not one that takes the entire data frame as a single partition:

    from pyspark.sql import Window
    from pyspark.sql.functions import desc, mean, min
    w = Window.partitionBy("col1", "col2").orderBy(desc("col1"))
    df = df.withColumn("col1_mean", mean("col1").over(w))
    

    How would I write a Window with everything as one partition?

  2. Is there a way to write this dynamically for all columns?

    Let's say I have 500 columns; it does not look great to write them out one by one:

    df = (df
        .withColumn("col1_mean", mean("col1").over(w))
        .withColumn("col1_min", min("col1").over(w))
        .withColumn("col2_mean", mean("col2").over(w))
        .....
    )
    

    Let's assume I want multiple stats for each column, so each colx would spawn colx_min, colx_max and colx_mean; a sketch of the shape I'm hoping for follows this list.
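
A minimal sketch of the shape I'm hoping for, assuming Window.partitionBy() with no columns treats the whole data frame as a single partition (at the cost of Spark moving all rows into one partition):

    from pyspark.sql import Window
    import pyspark.sql.functions as F

    # Empty partitionBy(): the whole data frame is one window partition.
    # Spark warns that all data is moved to a single partition.
    w = Window.partitionBy()

    # Build the stat columns dynamically instead of chaining withColumn calls
    stat_cols = [c for c in df.columns if c.startswith("col")]
    df.select(
        "*",
        *[F.min(c).over(w).alias(c + "_min") for c in stat_cols],
        *[F.mean(c).over(w).alias(c + "_mean") for c in stat_cols],
    ).show()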


Solution

  • Instead of using a window, you can achieve the same result with a custom aggregation in combination with a cross join:

    import pyspark.sql.functions as F
    from pyspark.sql.functions import broadcast
    from itertools import chain
    
    df = spark.createDataFrame([
      [1, 2.3, 1],
      [2, 5.3, 2],
      [3, 2.1, 4],
      [4, 1.5, 5]
    ], ["index", "col1", "col2"])
    
    agg_cols = [(F.min(c).alias("min_" + c),
                 F.max(c).alias("max_" + c),
                 F.mean(c).alias("mean_" + c))
                for c in df.columns if c.startswith('col')]

    stats_df = df.agg(*list(chain(*agg_cols)))
    
    # crossJoin has no real performance impact here: the right-hand table has only
    # one row and we broadcast it (most likely Spark would broadcast it anyway)
    df.crossJoin(broadcast(stats_df)).show()
    
    # +-----+----+----+--------+--------+---------+--------+--------+---------+
    # |index|col1|col2|min_col1|max_col1|mean_col1|min_col2|max_col2|mean_col2|
    # +-----+----+----+--------+--------+---------+--------+--------+---------+
    # |    1| 2.3|   1|     1.5|     5.3|      2.8|       1|       5|      3.0|
    # |    2| 5.3|   2|     1.5|     5.3|      2.8|       1|       5|      3.0|
    # |    3| 2.1|   4|     1.5|     5.3|      2.8|       1|       5|      3.0|
    # |    4| 1.5|   5|     1.5|     5.3|      2.8|       1|       5|      3.0|
    # +-----+----+----+--------+--------+---------+--------+--------+---------+
    

    Note 1: Using broadcast we avoid shuffling, since the broadcasted df is sent to all the executors.

    Note 2: with chain(*agg_cols) we flatten the list of tuples created in the previous step, as illustrated below.
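
    As a toy illustration of that flattening (hypothetical values, not taken from the solution above):

    from itertools import chain

    # chain(*pairs) unpacks the outer list and concatenates the tuples' elements
    pairs = [("a", "b"), ("c", "d")]
    print(list(chain(*pairs)))  # ['a', 'b', 'c', 'd']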

    UPDATE:

    Here is the execution plan for the above program:

    == Physical Plan ==
    *(3) BroadcastNestedLoopJoin BuildRight, Cross
    :- *(3) Scan ExistingRDD[index#196L,col1#197,col2#198L]
    +- BroadcastExchange IdentityBroadcastMode, [id=#274]
       +- *(2) HashAggregate(keys=[], functions=[finalmerge_min(merge min#233) AS min(col1#197)#202, finalmerge_max(merge max#235) AS max(col1#197)#204, finalmerge_avg(merge sum#238, count#239L) AS avg(col1#197)#206, finalmerge_min(merge min#241L) AS min(col2#198L)#208L, finalmerge_max(merge max#243L) AS max(col2#198L)#210L, finalmerge_avg(merge sum#246, count#247L) AS avg(col2#198L)#212])
          +- Exchange SinglePartition, [id=#270]
             +- *(1) HashAggregate(keys=[], functions=[partial_min(col1#197) AS min#233, partial_max(col1#197) AS max#235, partial_avg(col1#197) AS (sum#238, count#239L), partial_min(col2#198L) AS min#241L, partial_max(col2#198L) AS max#243L, partial_avg(col2#198L) AS (sum#246, count#247L)])
                +- *(1) Project [col1#197, col2#198L]
                   +- *(1) Scan ExistingRDD[index#196L,col1#197,col2#198L]
    

    Here we see a BroadcastExchange of a SinglePartition, which broadcasts one single row since stats_df fits into a single partition. Therefore the data being shuffled here is only one row (the minimum possible).
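
    For reference, the physical plan above can be reproduced with DataFrame.explain() (the exact node and expression IDs will differ between runs):

    # Print the physical plan of the cross-join query
    df.crossJoin(broadcast(stats_df)).explain()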