My problem is similar to this one, but instead of udf I need to use pandas_udf.
I have a Spark data frame with many columns (the number of columns varies), and I need to apply a custom function (for example, sum) across them. I know I can hard-code the column names, but that does not work when the number of columns varies.
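For example, a hard-coded pandas_udf might look like the sketch below (assuming a data frame df with exactly three columns named A, B and C; the name abc_sum is purely illustrative), and it breaks as soon as a fourth column appears:
>>> import pandas as pd
>>> import pyspark.sql.functions as F
>>> @F.pandas_udf("double")
... def abc_sum(a: pd.Series, b: pd.Series, c: pd.Series) -> pd.Series:
...     # only ever handles exactly three columns
...     return a + b + c
...
>>> df.withColumn('SUM', abc_sum('A', 'B', 'C')).show()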
Please see the examples below:
The solution is to use the * unpacking expression in the function call and the pd.concat method inside the pandas_udf function body:
>>> import pandas as pd
>>> import pyspark.sql.functions as F
>>> @F.pandas_udf("double")
... def col_sum(*args: pd.Series) -> pd.Series:
...     pdf = pd.concat(args, axis=1)
...     col_sum = pdf.sum(axis=1)
...     return col_sum
...
>>> df = spark.createDataFrame([(1,1,1),(2,2,2),(3,3,3)],["A","B","C"])
>>> df.withColumn('SUM', col_sum(*df.columns)).show()
+---+---+---+---+
| A| B| C|SUM|
+---+---+---+---+
| 1| 1| 1|3.0|
| 2| 2| 2|6.0|
| 3| 3| 3|9.0|
+---+---+---+---+
>>> df = spark.createDataFrame([(1,1,1,1),(2,2,2,2),(3,3,3,3)],["A","B","C"])
>>> df.withColumn('SUM', col_sum(*df.columns)).show()
+---+---+---+---+----+
| A| B| C| _4| SUM|
+---+---+---+---+----+
| 1| 1| 1| 1| 4.0|
| 2| 2| 2| 2| 8.0|
| 3| 3| 3| 3|12.0|
+---+---+---+---+----+
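The same function also works on any subset of columns. For example, to sum only the numeric columns you could filter the column list first (a quick sketch using df.dtypes; adjust the type names to match your schema):
>>> numeric_cols = [c for c, t in df.dtypes if t in ("int", "bigint", "float", "double")]
>>> df.withColumn('SUM', col_sum(*numeric_cols)).show()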