Tags: pyspark, apache-spark-sql, user-defined-functions

Pyspark: pass multiple columns in pandas_udf


My problem is similar to this earlier question, but instead of udf I need to use pandas_udf.

I have a Spark data frame with many columns, and the number of columns varies. I need to apply a custom function (for example, a sum) across them. I could hard-code the column names, but that breaks as soon as the number of columns changes.

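For context, here is a minimal sketch of the hard-coded approach (the column names A, B, C and an existing spark session are assumed for illustration); it works, but has to be rewritten whenever the set of columns changes:

    import pyspark.sql.functions as F

    df = spark.createDataFrame([(1, 1, 1), (2, 2, 2), (3, 3, 3)], ["A", "B", "C"])

    # the column names are baked into the expression, so this breaks
    # as soon as columns are added, removed, or renamed
    df.withColumn("SUM", F.col("A") + F.col("B") + F.col("C")).show()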


Solution

  • The solution is to give the pandas_udf a variadic *args signature, combine the incoming series with pd.concat inside the function body, and unpack the list of columns with the * operator at the call site:

    >>> import pandas as pd
    >>> import pyspark.sql.functions as F
    
    >>> @F.pandas_udf("double")
    ... def col_sum(*args: pd.Series) -> pd.Series:
    ...     # concatenate the incoming series column-wise, then reduce each row
    ...     pdf = pd.concat(args, axis=1)
    ...     return pdf.sum(axis=1)
    ... 
    
    >>> df = spark.createDataFrame([(1,1,1),(2,2,2),(3,3,3)],["A","B","C"])
    >>> df.withColumn('SUM', col_sum(*df.columns)).show()
    +---+---+---+---+
    |  A|  B|  C|SUM|
    +---+---+---+---+
    |  1|  1|  1|3.0|
    |  2|  2|  2|6.0|
    |  3|  3|  3|9.0|
    +---+---+---+---+
    
    >>> df = spark.createDataFrame([(1,1,1,1),(2,2,2,2),(3,3,3,3)],["A","B","C"])  # the unnamed 4th column gets the default name _4
    >>> df.withColumn('SUM', col_sum(*df.columns)).show()
    +---+---+---+---+----+
    |  A|  B|  C| _4| SUM|
    +---+---+---+---+----+
    |  1|  1|  1|  1| 4.0|
    |  2|  2|  2|  2| 8.0|
    |  3|  3|  3|  3|12.0|
    +---+---+---+---+----+
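
  • Because the UDF simply concatenates whatever series it receives, the same pattern extends to any row-wise pandas reduction and to subsets of columns. A minimal sketch (the col_mean name and the chosen columns are illustrative; df and the imports from above are assumed):

        @F.pandas_udf("double")
        def col_mean(*args: pd.Series) -> pd.Series:
            # same pattern, different reduction: a row-wise mean
            return pd.concat(args, axis=1).mean(axis=1)

        # the UDF accepts column names, so it can target a subset of columns
        df.withColumn('MEAN', col_mean('A', 'B')).show()

    Note that for a plain sum, native column arithmetic (for example functools.reduce(operator.add, [F.col(c) for c in df.columns])) avoids the Arrow serialization round-trip of a pandas_udf; the pandas_udf approach pays off when the row-wise logic genuinely needs pandas.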