apache-spark, pyspark, apache-spark-sql, user-defined-functions, pyspark-pandas

Create column using Spark pandas_udf, with dynamic number of input columns


I have this df:

df = spark.createDataFrame(
    [('row_a', 5.0, 0.0, 11.0),
     ('row_b', 3394.0, 0.0, 4543.0),
     ('row_c', 136111.0, 0.0, 219255.0),
     ('row_d', 0.0, 0.0, 0.0),
     ('row_e', 0.0, 0.0, 0.0),
     ('row_f', 42.0, 0.0, 54.0)],
    ['value', 'col_a', 'col_b', 'col_c']
)

I would like to use pandas' .quantile(0.25, axis=1), which adds one column holding the row-wise first quartile:

import pandas as pd
pdf = df.toPandas()
pdf['25%'] = pdf.quantile(0.25, axis=1, numeric_only=True)  # numeric_only skips the string column 'value'
print(pdf)
#    value     col_a  col_b     col_c      25%
# 0  row_a       5.0    0.0      11.0      2.5
# 1  row_b    3394.0    0.0    4543.0   1697.0
# 2  row_c  136111.0    0.0  219255.0  68055.5
# 3  row_d       0.0    0.0       0.0      0.0
# 4  row_e       0.0    0.0       0.0      0.0
# 5  row_f      42.0    0.0      54.0     21.0

Performance is important to me, so I assume that pandas_udf from pyspark.sql.functions could do this in a more optimized way. But I am struggling to write a function that is both performant and general. This is my best attempt:

from pyspark.sql import functions as F
import pandas as pd
@F.pandas_udf('double')
def quartile1_on_axis1(a: pd.Series, b: pd.Series, c: pd.Series) -> pd.Series:
    pdf = pd.DataFrame({'a':a, 'b':b, 'c':c})
    return pdf.quantile(0.25, axis=1)

df = df.withColumn('25%', quartile1_on_axis1('col_a', 'col_b', 'col_c'))

  1. I don't like that I need a separate argument for every column and then have to address each argument individually inside the function to build a DataFrame. All of these columns serve the same purpose, so IMHO there should be a way to address them all together, something like this pseudocode:

    def quartile1_on_axis1(*cols) -> pd.Series:
        pdf = pd.DataFrame(cols)
    

    This way I could use this function for any number of columns.

  2. Is it necessary to create a pd.DataFrame inside the UDF? To me this seems the same as the approach without a UDF (Spark df -> pandas df -> Spark df) shown above, and without the UDF the code is even shorter. Is it really worth making this work with pandas_udf for performance reasons? I think pandas_udf was designed specifically for this kind of purpose...


Solution

  • You can pass a single struct column instead of multiple columns, like this:

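    import pandas as pd
    from pyspark.sql import functions as F

    # A struct column arrives in the UDF as a pandas DataFrame,
    # with one DataFrame column per struct field.
    @F.pandas_udf('double')
    def quartile1_on_axis1(s: pd.DataFrame) -> pd.Series:
        return s.quantile(0.25, axis=1)

    # Any number of columns can be listed here.
    cols = ['col_a', 'col_b', 'col_c']

    df = df.withColumn('25%', quartile1_on_axis1(F.struct(*cols)))
    df.show()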
    
    # +-----+--------+-----+--------+-------+
    # |value|   col_a|col_b|   col_c|    25%|
    # +-----+--------+-----+--------+-------+
    # |row_a|     5.0|  0.0|    11.0|    2.5|
    # |row_b|  3394.0|  0.0|  4543.0| 1697.0|
    # |row_c|136111.0|  0.0|219255.0|68055.5|
    # |row_d|     0.0|  0.0|     0.0|    0.0|
    # |row_e|     0.0|  0.0|     0.0|    0.0|
    # |row_f|    42.0|  0.0|    54.0|   21.0|
    # +-----+--------+-----+--------+-------+
    

    From the documentation of pyspark.sql.functions.pandas_udf:

    "Note that the type hint should use pandas.Series in all cases but there is one variant that pandas.DataFrame should be used for its input or output type hint instead when the input or output column is of pyspark.sql.types.StructType."
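
Since the struct column reaches the function as a plain pandas.DataFrame, the body of the UDF can be tried out locally without a Spark session. The following is only a sketch under that assumption: the name quartile1_on_axis1_body and the hand-built batch DataFrame are hypothetical stand-ins for the decorated UDF and for the data Spark would pass in, and the values are copied from the question.

```python
import pandas as pd


# Hypothetical undecorated copy of the UDF body, for local testing only.
def quartile1_on_axis1_body(s: pd.DataFrame) -> pd.Series:
    """Row-wise first quartile: the same expression the UDF returns."""
    return s.quantile(0.25, axis=1)


# Assumed stand-in for the batch Spark would build from F.struct(*cols):
# one DataFrame column per struct field, values taken from the question.
batch = pd.DataFrame({
    'col_a': [5.0, 3394.0, 136111.0, 0.0, 0.0, 42.0],
    'col_b': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    'col_c': [11.0, 4543.0, 219255.0, 0.0, 0.0, 54.0],
})

# The same function handles any number of columns.
three_cols = quartile1_on_axis1_body(batch)
two_cols = quartile1_on_axis1_body(batch[['col_a', 'col_c']])

print(three_cols.tolist())
print(two_cols.tolist())
```

With all three columns, this reproduces the 25% column from the tables above. The function never refers to a column by name, which is why the number of columns placed in the struct can vary.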