Similar to this question (Scala), but I need pair combinations of an array column in PySpark.
Example input:
df = spark.createDataFrame(
    [([0, 1],),
     ([2, 3, 4],),
     ([5, 6, 7, 8],)],
    ['array_col'])
Expected output:
+------------+------------------------------------------------+
|array_col |out |
+------------+------------------------------------------------+
|[0, 1] |[[0, 1]] |
|[2, 3, 4] |[[2, 3], [2, 4], [3, 4]] |
|[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
+------------+------------------------------------------------+
A pandas_udf is an efficient and concise approach in PySpark.
from pyspark.sql import functions as F
import pandas as pd
from itertools import combinations
@F.pandas_udf('array<array<int>>')
def pudf(c: pd.Series) -> pd.Series:
    # For each row, return every 2-element combination of the array
    return c.apply(lambda x: list(combinations(x, 2)))
df = df.withColumn('out', pudf('array_col'))
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col |out |
# +------------+------------------------------------------------+
# |[0, 1] |[[0, 1]] |
# |[2, 3, 4] |[[2, 3], [2, 4], [3, 4]] |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+
Note: in some environments, instead of the DDL string 'array<array<int>>' you may need to provide the return type using pyspark.sql.types, e.g. ArrayType(ArrayType(IntegerType())).
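For reference, here is a minimal sketch of that variant (it assumes the array elements fit in an IntegerType; swap in LongType for larger values):

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, IntegerType
import pandas as pd
from itertools import combinations

# Same UDF as above, but with the return type built from pyspark.sql.types
# instead of the DDL string 'array<array<int>>'
@F.pandas_udf(ArrayType(ArrayType(IntegerType())))
def pudf(c: pd.Series) -> pd.Series:
    return c.apply(lambda x: list(combinations(x, 2)))

df = df.withColumn('out', pudf('array_col'))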