
Aggregations according to a boolean parameter inside a function


I have a custom function for starting a Spark job. The main goal is to group a table with multiple aggregations:

.groupby(["some", "columns")
.agg(
   F.mean("col1").alias("col1_mean"),
   F.sum("col1").alias("col1_sum")                         
)

Now I'd like to make the aggregations more flexible. Is there a way to include or exclude aggregations based on a boolean parameter? Something like:

def spark_function(mean_agg=True, sum_agg=False):
   [...]
   .groupby(["some", "columns")
   .agg(
      F.mean("col1").alias("col1_mean"), # only if mean_agg=True
      F.sum("col1").alias("col1_sum") # only if sum_agg=True                          
   )
   [...]

In the real world, some aggregations will always be performed, so there is no need to check whether at least one flag is True.


Solution

  • Yes, you can. You will have to adapt this to your own needs, but here is an example. I suggest keeping groupBy outside the function, since you only write it once - it is not the repetitive part.

    Input df:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    # a running SparkSession is assumed; getOrCreate() reuses one if it exists
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [('a', 3),
         ('a', 5)],
        ['id', 'col1']
    )
    

    Example function:

    def spark_agg_function(cols, sum_agg=False):
        # the mean aggregation is always included; the sum is added
        # only when sum_agg=True
        aggs = [F.mean(c).alias(f"{c}_mean") for c in cols]
        if sum_agg:
            aggs += [F.sum(c).alias(f"{c}_sum") for c in cols]
        return aggs
    

    Test:

    df.groupBy('id').agg(
        *spark_agg_function(['col1'])
    ).show()
    # +---+---------+
    # | id|col1_mean|
    # +---+---------+
    # |  a|      4.0|
    # +---+---------+
    
    df.groupBy('id').agg(
        *spark_agg_function(['col1'], sum_agg=True)
    ).show()
    # +---+---------+--------+
    # | id|col1_mean|col1_sum|
    # +---+---------+--------+
    # |  a|      4.0|       8|
    # +---+---------+--------+
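
    To match the signature from the question, where mean_agg and sum_agg are both flags and some aggregations always run, the same list-building pattern extends naturally. Below is a minimal sketch; spark_agg_function_v2 is a hypothetical name, and F.count stands in for the aggregations that are always done. On the same input df it should print:

    def spark_agg_function_v2(cols, mean_agg=True, sum_agg=False):
        # count always runs (stand-in for the "always done" aggregations);
        # mean and sum are gated by their flags
        aggs = [F.count(c).alias(f"{c}_count") for c in cols]
        if mean_agg:
            aggs += [F.mean(c).alias(f"{c}_mean") for c in cols]
        if sum_agg:
            aggs += [F.sum(c).alias(f"{c}_sum") for c in cols]
        return aggs

    df.groupBy('id').agg(
        *spark_agg_function_v2(['col1'], sum_agg=True)
    ).show()
    # +---+----------+---------+--------+
    # | id|col1_count|col1_mean|col1_sum|
    # +---+----------+---------+--------+
    # |  a|         2|      4.0|       8|
    # +---+----------+---------+--------+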