pyspark, multiple-columns, aggregate-functions, outliers

pyspark agg without outliers for each column


I am facing a problem where I need to use the agg function to calculate statistics without outliers for multiple columns. For each column I need to drop the values outside the 25th and 75th percentiles and calculate the min, max, and mean of what remains.

The input table:

df = spark.createDataFrame(
    [
        ('a', 'E1', 'C1', 1, 1, 1),
        ('a', 'E1', 'C1', 2, 12, 22),
        ('a', 'E1', 'C1', 3, 13, 23),
        ('a', 'E1', 'C1', 4, 133, 123),
        ('b', 'E1', 'C1', 1, 1, 2),
        ('b', 'E1', 'C1', 2, 15, 25),
        ('b', 'E1', 'C1', 3, 56, 126),
        ('b', 'E1', 'C1', 4, 156, 126),
    ],
    schema=['sheet', 'equipment', 'chamber', 'time', 'value1', 'value2']
)

df.printSchema()
df.show(10, False)

+-----+---------+-------+----+------+------+
|sheet|equipment|chamber|time|value1|value2|
+-----+---------+-------+----+------+------+
|a    |E1       |C1     |1   |1     |1     |
|a    |E1       |C1     |2   |12    |22    |
|a    |E1       |C1     |3   |13    |23    |
|a    |E1       |C1     |4   |133   |123   |
|b    |E1       |C1     |1   |1     |2     |
|b    |E1       |C1     |2   |15    |25    |
|b    |E1       |C1     |3   |16    |26    |
|b    |E1       |C1     |4   |156   |126   |
+-----+---------+-------+----+------+------+

The expected result:

+-----+---------+-------+----------+----------+-----------+----------+----------+-----------+
|sheet|equipment|chamber|value1_min|value1_max|value1_mean|value2_min|value2_max|value2_mean|
+-----+---------+-------+----------+----------+-----------+----------+----------+-----------+
|a    |E1       |C1     |12        |13        |12.5       |22        |23        |22.5       |
|b    |E1       |C1     |15        |16        |15.5       |25        |26        |25.5       |
+-----+---------+-------+----------+----------+-----------+----------+----------+-----------+
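
To make the expected numbers concrete, here is a rough check of the sheet a / value1 row in plain Python, assuming the intent is to drop the values in the bottom and top quartiles before aggregating:

# value1 for sheet 'a', sorted
vals = sorted([1, 12, 13, 133])
# drop the lowest and highest quartile values (one from each end for 4 rows)
trimmed = vals[1:-1]                                             # [12, 13]
print(min(trimmed), max(trimmed), sum(trimmed) / len(trimmed))   # 12 13 12.5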

Here is my code so far. It loops once for each column; is there a more efficient way to express this?

    import pyspark.sql.functions as func
    from pyspark.sql import window as W

    groupby_list = ["sheet"]
    dummy_origin = df.select(groupby_list).dropDuplicates(groupby_list)
    w = W.Window.partitionBy(groupby_list)

    # keep only the value columns to iterate over
    param_df = df.drop(*groupby_list, 'equipment', 'chamber', 'time')

    for col_name in param_df.columns:
        # for each column, compute statistics on the rows inside the p25-p75 range and then join
        aggregation = [func.mean(col_name).alias(f"{col_name}_mean"),
                       func.stddev(col_name).alias(f"{col_name}_std"),
                       func.min(col_name).alias(f"{col_name}_min"),
                       func.max(col_name).alias(f"{col_name}_max")]
        df_25_75 = (df.select(*groupby_list, col_name)
                    .withColumn("p25", func.percentile_approx(func.col(col_name), 0.25).over(w))
                    .withColumn("p75", func.percentile_approx(func.col(col_name), 0.75).over(w))
                    .withColumn("in_range", func.when((func.col(col_name) <= func.col('p75')) & (func.col(col_name) >= func.col('p25')), 1).otherwise(0))
                    .where(func.col('in_range') == 1)
                    .groupby(*groupby_list).agg(*aggregation))
        dummy_origin = dummy_origin.join(df_25_75, groupby_list, 'inner')

Solution

  • Using Jonathan's logic, we can shorten the code a little by using structs for the percentiles. Note that percentile_approx also accepts a list of percentages; when a list is passed, it returns the percentiles as an array in which the Nth element corresponds to the Nth percentage in the list (a minimal standalone sketch of this follows the output below).

    import pyspark.sql.functions as func
    from pyspark.sql import Window as wd

    # data_sdf is the input dataframe (df in the question)
    value_fields = ['value1', 'value2']

    # condition to be reused for every value column: keep only the values lying between
    # that column's 25th and 75th percentiles stored in the `valcol_pers` struct
    column_condition = lambda c: func.col(c).between(func.col('valcol_pers.'+c+'_per')[0], func.col('valcol_pers.'+c+'_per')[1])

    # calculate the 0.25 & 0.75 percentiles and store them in a struct - `valcol_pers` -
    # where the struct field names indicate the column the percentiles refer to
    data_sdf. \
        withColumn('valcol_pers', 
                   func.struct(*[func.percentile_approx(c, [0.25, 0.75]).over(wd.partitionBy('sheet')).alias(c+'_per') for c in value_fields])
                   ). \
        groupBy('sheet', 'equipment', 'chamber'). \
        agg(*[func.min(func.when(column_condition(c), func.col(c))).alias(c+'_min') for c in value_fields],
            *[func.max(func.when(column_condition(c), func.col(c))).alias(c+'_max') for c in value_fields],
            *[func.mean(func.when(column_condition(c), func.col(c))).alias(c+'_mean') for c in value_fields],
            *[func.stddev(func.when(column_condition(c), func.col(c))).alias(c+'_stddev') for c in value_fields]
            ). \
        show(truncate=False)
    
    # +-----+---------+-------+----------+----------+----------+----------+-----------------+------------------+------------------+------------------+
    # |sheet|equipment|chamber|value1_min|value2_min|value1_max|value2_max|value1_mean      |value2_mean       |value1_stddev     |value2_stddev     |
    # +-----+---------+-------+----------+----------+----------+----------+-----------------+------------------+------------------+------------------+
    # |b    |E1       |C1     |1         |2         |56        |126       |24.0             |69.75             |28.583211855912904|65.62710314090259 |
    # |a    |E1       |C1     |1         |1         |13        |23        |8.666666666666666|15.333333333333334|6.658328118479393 |12.423096769056148|
    # +-----+---------+-------+----------+----------+----------+----------+-----------------+------------------+------------------+------------------+
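
    As a quick standalone illustration of the list-of-percentages behaviour mentioned above, here is a minimal sketch run as a plain groupBy aggregation on the question's df (the value1_per alias is just an illustrative name); the Nth element of the returned array corresponds to the Nth percentage passed in:

    df. \
        groupBy('sheet'). \
        agg(func.percentile_approx('value1', [0.25, 0.75]).alias('value1_per')). \
        show(truncate=False)

    # which should give something like
    # +-----+----------+
    # |sheet|value1_per|
    # +-----+----------+
    # |b    |[1, 56]   |
    # |a    |[1, 13]   |
    # +-----+----------+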
    

    The valcol_pers column would look like the following (here the intermediate dataframe is assigned to pers_sdf so that both the rows and the schema can be shown):

    pers_sdf = data_sdf. \
        withColumn('valcol_pers', 
                   func.struct(*[func.percentile_approx(c, [0.25, 0.75]).over(wd.partitionBy('sheet')).alias(c+'_per') for c in value_fields])
                   )

    pers_sdf.show(truncate=False)
    pers_sdf.printSchema()
    
    +-----+---------+-------+----+------+------+-------------------+
    |sheet|equipment|chamber|time|value1|value2|valcol_pers        |
    +-----+---------+-------+----+------+------+-------------------+
    |b    |E1       |C1     |1   |1     |2     |{[1, 56], [2, 126]}|
    |b    |E1       |C1     |2   |15    |25    |{[1, 56], [2, 126]}|
    |b    |E1       |C1     |3   |56    |126   |{[1, 56], [2, 126]}|
    |b    |E1       |C1     |4   |156   |126   |{[1, 56], [2, 126]}|
    |a    |E1       |C1     |1   |1     |1     |{[1, 13], [1, 23]} |
    |a    |E1       |C1     |2   |12    |22    |{[1, 13], [1, 23]} |
    |a    |E1       |C1     |3   |13    |23    |{[1, 13], [1, 23]} |
    |a    |E1       |C1     |4   |133   |123   |{[1, 13], [1, 23]} |
    +-----+---------+-------+----+------+------+-------------------+
    
    root
     |-- sheet: string (nullable = true)
     |-- equipment: string (nullable = true)
     |-- chamber: string (nullable = true)
     |-- time: long (nullable = true)
     |-- value1: long (nullable = true)
     |-- value2: long (nullable = true)
     |-- valcol_pers: struct (nullable = false)
     |    |-- value1_per: array (nullable = true)
     |    |    |-- element: long (containsNull = false)
     |    |-- value2_per: array (nullable = true)
     |    |    |-- element: long (containsNull = false)
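
    To make the column_condition lambda above more concrete, here is a minimal sketch of how the per-column percentile bounds are read back out of the struct and used as a between filter, using the pers_sdf dataframe built above (the value1_p25/value1_p75/in_range aliases are just illustrative names):

    pers_sdf. \
        select('sheet', 'value1',
               func.col('valcol_pers.value1_per')[0].alias('value1_p25'),
               func.col('valcol_pers.value1_per')[1].alias('value1_p75'),
               # this is the same check that column_condition builds for each value column
               func.col('value1').between(func.col('valcol_pers.value1_per')[0],
                                          func.col('valcol_pers.value1_per')[1]).alias('in_range')
               ). \
        show(truncate=False)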