
How to get all combinations of an array column in Spark?


Suppose I have a DataFrame with an array column group_ids:

+-------+----------+
|user_id|group_ids |
+-------+----------+
|1      |[5, 8]    |
|3      |[1, 2, 3] |
|2      |[1, 4]    |
+-------+----------+

Schema:

root
 |-- user_id: integer (nullable = false)
 |-- group_ids: array (nullable = false)
 |    |-- element: integer (containsNull = false)

I want to get all pairwise combinations of the elements of group_ids:

+-------+------------------------+
|user_id|group_ids               |
+-------+------------------------+
|1      |[[5, 8]]                |
|3      |[[1, 2], [1, 3], [2, 3]]|
|2      |[[1, 4]]                |
+-------+------------------------+

So far, the simplest solution I have found uses a UDF:

import org.apache.spark.sql.functions.{expr, udf}

spark.udf.register("permutate", udf((xs: Seq[Int]) => xs.combinations(2).toSeq))

dataset.withColumn("group_ids", expr("permutate(group_ids)"))

What I'm looking for is something implemented with Spark built-in functions. Is there a way to do the same thing without a UDF?


Solution

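  • A combination of higher-order functions can do this without a UDF. This requires Spark >= 2.4. Note that the final filter compares values (x[0] < x[1]), so the result matches combinations(2) only when the array elements are distinct, and each pair comes out in ascending order rather than in the array's original order.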

    val df2 = df.withColumn(
        "group_ids", 
        expr("""
            filter(
                transform(
                    flatten(
                        transform(
                            group_ids, 
                            x -> arrays_zip(
                                array_repeat(x, size(group_ids)), 
                                group_ids
                            )
                        )
                    ), 
                    x -> array(x['0'], x['group_ids'])
                ), 
                x -> x[0] < x[1]
            )
        """)
    )
    
    
    df2.show(false)
    +-------+------------------------+
    |user_id|group_ids               |
    +-------+------------------------+
    |1      |[[5, 8]]                |
    |3      |[[1, 2], [1, 3], [2, 3]]|
    |2      |[[1, 4]]                |
    +-------+------------------------+
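To see what the SQL expression computes step by step, here is a minimal plain-Scala sketch (no Spark required) that mirrors each higher-order function with the equivalent collection operation: the outer transform plus flatten as flatMap, array_repeat plus arrays_zip as Seq.fill plus zip, and the last transform and filter as map and filter. The object PairsModel and its method names are hypothetical, chosen for illustration; this models the semantics only and is not how Spark evaluates the expression.

```scala
// Hypothetical plain-Scala model of the built-in expression, next to the question's UDF.
object PairsModel {

  // Mirrors: flatten(transform(group_ids, x -> arrays_zip(array_repeat(x, size(group_ids)), group_ids)))
  // Each element x is paired with every element of the array (a cross product of the array with itself).
  def crossWithSelf(groupIds: Seq[Int]): Seq[(Int, Int)] =
    groupIds.flatMap { x =>
      Seq.fill(groupIds.size)(x).zip(groupIds) // array_repeat + arrays_zip
    }

  // Mirrors: transform(..., x -> array(x['0'], x['group_ids'])) then filter(..., x -> x[0] < x[1])
  // Keeps only pairs whose first value is strictly smaller than the second.
  def builtinPairs(groupIds: Seq[Int]): Seq[Seq[Int]] =
    crossWithSelf(groupIds)
      .map { case (a, b) => Seq(a, b) }
      .filter(p => p(0) < p(1))

  // The UDF logic from the question: positional combinations of size 2.
  def udfPairs(groupIds: Seq[Int]): Seq[Seq[Int]] =
    groupIds.combinations(2).toList
}
```

On the example rows both agree, for instance builtinPairs(Seq(1, 2, 3)) equals Seq(Seq(1, 2), Seq(1, 3), Seq(2, 3)). They differ on unsorted or repeated values: builtinPairs(Seq(3, 1)) gives Seq(Seq(1, 3)) while udfPairs(Seq(3, 1)) gives Seq(Seq(3, 1)), and builtinPairs(Seq(1, 1)) is empty while udfPairs(Seq(1, 1)) gives Seq(Seq(1, 1)).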