pandas · apache-spark · pyspark · apache-spark-sql

How to use PySpark's @pandas_udf with groupby.agg


I am using the pandas API on Spark with the groupby.agg operation. I found a similar issue, but the solution does not work for me. I also checked the official docs, but they do not provide enough examples.

Here is my sample data:

  L_SHIPMODE O_ORDERPRIORITY
0       MAIL         2 -HIGH
1       SHIP       1 -URGENT

I just want to group by L_SHIPMODE and count the O_ORDERPRIORITY values with the following UDFs:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import IntegerType

@pandas_udf(IntegerType())
def g1(x):
    # count rows that are "1-URGENT" or "2-HIGH"
    return ((x == "1-URGENT") | (x == "2-HIGH")).sum()

@pandas_udf(IntegerType())
def g2(x):
    # count rows that are neither "1-URGENT" nor "2-HIGH"
    return ((x != "1-URGENT") & (x != "2-HIGH")).sum()

# tried registering the UDFs, but that does not seem to be the problem
# spark.udf.register('g1_udf', g1)
# spark.udf.register('g2_udf', g2)

total = jn.groupby("L_SHIPMODE", as_index=False)["O_ORDERPRIORITY"].agg({"O_ORDERPRIORITY": [g1, g2]})

I got:

ValueError: aggs must be a dict mapping from column name to aggregate functions (string or list of strings).

Are there any detailed examples of how to use a UDF with groupby.agg?


Solution

  • I think you can avoid using UDFs altogether if you create some new columns before your groupby (using a sample dataframe similar to yours):

    import pyspark.pandas as ps
    
    jn = ps.DataFrame({
        'L_SHIPMODE': ['MAIL']*4+['SHIP']*4,
        'O_ORDERPRIORITY': [
            '1 -URGENT', '2 -HIGH', '3 -OTHER', '4 -OTHER',
            '3 -OTHER', '4 -OTHER', '3 -OTHER', '4 -OTHER',
        ]
    })
    
      L_SHIPMODE O_ORDERPRIORITY
    0       MAIL       1 -URGENT
    1       MAIL         2 -HIGH
    2       MAIL        3 -OTHER
    3       MAIL        4 -OTHER
    4       SHIP        3 -OTHER
    5       SHIP        4 -OTHER
    6       SHIP        3 -OTHER
    7       SHIP        4 -OTHER
    
    jn['1_OR_2'] = jn['O_ORDERPRIORITY'].isin(["1 -URGENT", "2 -HIGH"]).astype('long')
    jn['not_1_OR_2'] = (~jn['O_ORDERPRIORITY'].isin(["1 -URGENT", "2 -HIGH"])).astype('long')
    
    jn.groupby(
        'L_SHIPMODE'
    ).agg({
        '1_OR_2': 'sum',
        'not_1_OR_2': 'sum'
    })
    

    Result:

                1_OR_2  not_1_OR_2
    L_SHIPMODE
    MAIL             2           2
    SHIP             0           4
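
    If you do want to keep the @pandas_udf idea from the question, a minimal sketch (assuming Spark 3.x, where the UDF kind is inferred from the pd.Series -> int type hints) is to drop down to the Spark DataFrame behind the pandas-on-Spark frame via to_spark() and pass Series-to-scalar ("grouped aggregate") pandas UDFs straight to groupBy(...).agg(...). The names n_high/n_other and the output aliases below are just for illustration, and the string literals keep the spacing used in the sample data above:

    import pandas as pd
    from pyspark.sql.functions import col, pandas_udf

    # Spark DataFrame backing the pandas-on-Spark frame built above
    sdf = jn.to_spark()

    # Series-to-scalar ("grouped aggregate") pandas UDFs: each one receives the
    # column values of a single group as a pandas Series and returns one number.
    @pandas_udf("long")
    def n_high(x: pd.Series) -> int:
        return int(((x == "1 -URGENT") | (x == "2 -HIGH")).sum())

    @pandas_udf("long")
    def n_other(x: pd.Series) -> int:
        return int(((x != "1 -URGENT") & (x != "2 -HIGH")).sum())

    (sdf.groupBy("L_SHIPMODE")
        .agg(n_high(col("O_ORDERPRIORITY")).alias("1_OR_2"),
             n_other(col("O_ORDERPRIORITY")).alias("not_1_OR_2"))
        .show())

    This should give the same counts as the column-based approach above. Staying inside the pandas API on Spark, the ValueError is telling you that the dict form of groupby.agg only accepts string function names (or lists of them); for arbitrary per-group logic there, groupby("L_SHIPMODE").apply(func), where func receives each group as a pandas DataFrame, is the usual escape hatch.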