Tags: pyspark, apache-spark-sql, aggregate-functions, scalar

How do you use aggregated values within PySpark SQL when() clause?


I am trying to learn PySpark and have been working with SQL when() clauses to better categorize my data (see here: https://sparkbyexamples.com/spark/spark-case-when-otherwise-example/). What I can't figure out is how to explicitly get actual scalar values into the when() conditions for comparison. The aggregate functions seem to return tabular values rather than plain float types.
I keep getting this error message: unsupported operand type(s) for -: 'method' and 'method'
When I ran an aggregate function against a column of the original DataFrame, I noticed the result wasn't a flat scalar so much as a table: samp.agg(f.stddev("Col")) gives a result like "DataFrame[stddev_samp(TAXI_OUT): double]". Here is a sample of what I am trying to accomplish, if you want to replicate it. How would you get aggregate values like the standard deviation and mean into the when() clause so you can use them to categorize a new column?

```
from pyspark.sql import functions as F

samp = spark.createDataFrame(
    [("A","A1",4,1.25),("B","B3",3,2.14),("C","C2",7,4.24),
     ("A","A3",4,1.25),("B","B1",3,2.14),("C","C1",7,4.24)],
    ["Category","Sub-cat","quantity","cost"])

psMean = samp.agg({'quantity':'mean'})
psStDev = samp.agg({'quantity':'stddev'})

psCatVect = samp.withColumn('quant_category',
    F.when(samp['quantity'] <= (psMean - psStDev), 'small').otherwise('not small'))
```


Solution

  • psMean and psStDev in your example are DataFrames; you need to use the collect() method to extract the scalar values:

```
# collect() returns a list of Row objects; [0][0] unwraps the single aggregate value to a Python float
psMean = samp.agg({'quantity':'mean'}).collect()[0][0]
psStDev = samp.agg({'quantity':'stddev'}).collect()[0][0]
```
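With those scalars collected, the comparison inside when() is an ordinary Column-to-float expression. Here is a minimal end-to-end sketch of the asker's example, assuming the usual `from pyspark.sql import functions as F` import and an active SparkSession:

```
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

samp = spark.createDataFrame(
    [("A","A1",4,1.25),("B","B3",3,2.14),("C","C2",7,4.24),
     ("A","A3",4,1.25),("B","B1",3,2.14),("C","C1",7,4.24)],
    ["Category","Sub-cat","quantity","cost"])

# Each agg() call yields a 1x1 DataFrame; collect()[0][0] unwraps it to a float
psMean = samp.agg({'quantity':'mean'}).collect()[0][0]
psStDev = samp.agg({'quantity':'stddev'}).collect()[0][0]

# The condition is now Column <= float, which when() handles directly
psCatVect = samp.withColumn('quant_category',
    F.when(samp['quantity'] <= (psMean - psStDev), 'small').otherwise('not small'))

psCatVect.show()
```

Equivalently, .first()[0] pulls the same single value without materializing the full result list.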