I am using a grouped aggregate Pandas UDF to average the values of an array column element-wise (aka mean pooling). I keep getting the following warning and have not been able to find the correct type hints to use in place of PandasUDFType.GROUPED_AGG with a return type of ArrayType(DoubleType()):
UserWarning: In Python 3.6+ and Spark 3.0+, it is preferred to specify type hints for pandas UDF instead of specifying pandas UDF type which will be deprecated in the future releases. See SPARK-28264 for more details.
Reference code with data
from pyspark.sql.types import ArrayType, DoubleType
import pyspark.sql.functions as F
import numpy as np
import pandas as pd
# create data
pdf = pd.DataFrame(
{"id": [1, 1, 2, 2], "x": [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]}
)
df = spark.createDataFrame(pdf)
df.show()
# define pooling udf
@F.pandas_udf(
returnType=ArrayType(DoubleType()), functionType=F.PandasUDFType.GROUPED_AGG
)
def mean_pooling_udf(x: pd.Series) -> pd.Series:
    return np.mean(x, axis=0)
# apply udf and display
df.groupby("id").agg(mean_pooling_udf(df["x"])).show()
Output
+---+----------+
| id|         x|
+---+----------+
|  1|[1.0, 1.0]|
|  1|[2.0, 2.0]|
|  2|[3.0, 3.0]|
|  2|[4.0, 4.0]|
+---+----------+
+---+-------------------+
| id|mean_pooling_udf(x)|
+---+-------------------+
|  1|         [1.5, 1.5]|
|  2|         [3.5, 3.5]|
+---+-------------------+
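The expected output above can be checked without a Spark session by running the same per-group computation in plain pandas on the question's data. This is a minimal sketch: `mean_pooling` is a hypothetical helper, the dict comprehension only emulates `df.groupby("id").agg(...)`, and `np.stack(...).mean(axis=0)` is a slightly more explicit variant of the question's `np.mean(x, axis=0)` that assumes every array in a group has the same length.

```python
import numpy as np
import pandas as pd


def mean_pooling(x: pd.Series) -> np.ndarray:
    # Hypothetical helper: stack the group's equal-length arrays into a
    # (rows, dims) matrix, then average down the rows (one mean per dimension)
    return np.stack(x.to_numpy()).mean(axis=0)


# Same data as the question
pdf = pd.DataFrame(
    {"id": [1, 1, 2, 2], "x": [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]}
)

# Emulate df.groupby("id").agg(...) by pooling each group's "x" values
pooled = {key: mean_pooling(group["x"]) for key, group in pdf.groupby("id")}

for key, value in pooled.items():
    print(key, value.tolist())
# 1 [1.5, 1.5]
# 2 [3.5, 3.5]
```

Stacking first makes the element-wise intent explicit and fails loudly if a group contains arrays of different lengths.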
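Drop the functionType argument and let Spark infer the UDF type from the type hints; passing a PandasUDFType explicitly is what triggers the warning. Note that with type hints, pd.Series -> pd.Series is inferred as a Series-to-Series (scalar) UDF, so for a grouped aggregate the return hint must be a non-pandas type such as np.ndarray. The return type itself can be given as a Data Definition Language (DDL) string such as 'array<double>' (https://vincent.doba.fr/posts/20211004_spark_data_description_language_for_defining_spark_schema/) or as ArrayType(DoubleType()):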
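import numpy as np
import pandas as pd
from pyspark.sql.functions import pandas_udf
# define pooling udf: a pd.Series -> np.ndarray hint is inferred as a grouped aggregate UDF
@pandas_udf('array<double>')
def mean_pooling_udf(x: pd.Series) -> np.ndarray:
    return np.mean(x, axis=0)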
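Why the return hint matters: when functionType is omitted, Spark 3.x decides which pandas UDF kind to build from the function's annotations. The sketch below is a simplified, hypothetical model of that rule (`infer_pandas_udf_kind` is not a PySpark API; the real logic lives in PySpark's internal `pyspark/sql/pandas/typehints.py` and also handles Iterator, Tuple and Union hints). It only reads the signatures, so it runs without Spark.

```python
import inspect

import numpy as np
import pandas as pd


def infer_pandas_udf_kind(func) -> str:
    # Hypothetical, simplified model of Spark 3.x type-hint inference
    # (ignores the Iterator/Tuple/Union cases PySpark also supports)
    sig = inspect.signature(func)
    params = [p.annotation for p in sig.parameters.values()]
    ret = sig.return_annotation
    if not params or not all(a in (pd.Series, pd.DataFrame) for a in params):
        raise ValueError("parameters must be annotated as pd.Series or pd.DataFrame")
    if ret in (pd.Series, pd.DataFrame):
        # A pandas return type means one output row per input row
        return "Series to Series (scalar)"
    # Any other return hint (np.ndarray, float, ...) means one value per group
    return "Series to Scalar (grouped aggregate)"


def pooled_with_series_hint(x: pd.Series) -> pd.Series:  # hints from the question
    return np.mean(x, axis=0)


def pooled_with_ndarray_hint(x: pd.Series) -> np.ndarray:  # hints from the answer
    return np.mean(x, axis=0)


print(infer_pandas_udf_kind(pooled_with_series_hint))   # Series to Series (scalar)
print(infer_pandas_udf_kind(pooled_with_ndarray_hint))  # Series to Scalar (grouped aggregate)
```

This is why simply deleting functionType from the question's UDF is not enough: its pd.Series return hint would turn it into a scalar UDF, which cannot be used inside agg().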