
Comparing two array columns in Scala Spark


I have a DataFrame in the following format:

movieId1 | genreList1              | genreList2
--------------------------------------------------
1        |[Adventure,Comedy]       |[Adventure]
2        |[Animation,Drama,War]    |[War,Drama]
3        |[Adventure,Drama]        |[Drama,War]

I am trying to add a flag column that shows whether genreList2 is a subset of genreList1:

movieId1 | genreList1              | genreList2        | Flag
---------------------------------------------------------------
1        |[Adventure,Comedy]       | [Adventure]       |1
2        |[Animation,Drama,War]    | [War,Drama]       |1
3        |[Adventure,Drama]        | [Drama,War]       |0

I have tried this:

def intersect_check(a: Array[String], b: Array[String]): Int = {
  if (b.sameElements(a.intersect(b))) { return 1 } 
  else { return 2 }
}

def intersect_check_udf =
  udf((colvalue1: Array[String], colvalue2: Array[String]) => intersect_check(colvalue1, colvalue2))

data = data.withColumn("Flag", intersect_check_udf(col("genreList1"), col("genreList2")))

But this throws an error:

org.apache.spark.SparkException: Failed to execute user defined function.

P.S.: The function above (intersect_check) works on plain Scala Arrays.


Solution

  • We can define a UDF that computes the length of the intersection of the two array columns and checks whether it equals the length of the second column. If it does, the second array is a subset of the first.

    Also, the inputs of your UDF need to be of type WrappedArray[String], not Array[String]. Spark passes an array column to a Scala UDF as a Seq (a WrappedArray on Scala 2.12 and earlier), which is why an Array[String] parameter fails at runtime:

    import scala.collection.mutable.WrappedArray
    import org.apache.spark.sql.functions.{col, udf}
    
    // Returns 1 when every element of b (counting duplicates) also occurs in a, else 0.
    val same_elements = udf { (a: WrappedArray[String],
                               b: WrappedArray[String]) =>
      if (a.intersect(b).length == b.length) 1 else 0
    }
    
    df.withColumn("test", same_elements(col("genreList1"), col("genreList2")))
      .show(truncate = false)
    +--------+-----------------------+------------+----+
    |movieId1|genreList1             |genreList2  |test|
    +--------+-----------------------+------------+----+
    |1       |[Adventure, Comedy]    |[Adventure] |1   |
    |2       |[Animation, Drama, War]|[War, Drama]|1   |
    |3       |[Adventure, Drama]     |[Drama, War]|0   |
    +--------+-----------------------+------------+----+
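
The choice of check matters for ordering and duplicates. The sketch below, in plain Scala with no Spark dependency, compares the question's sameElements test, the length comparison used above, and a set-based variant. The object and method names are hypothetical, chosen for illustration only.

```scala
// Hypothetical names for illustration; plain Scala collections, no Spark required.
object SubsetCheck {
  // The question's test: order-sensitive, because intersect keeps the order of `a`.
  def bySameElements(a: Seq[String], b: Seq[String]): Boolean =
    b.sameElements(a.intersect(b))

  // The answer's test: intersect is a multiset intersection, so each
  // duplicate in `b` must be matched by its own occurrence in `a`.
  def byIntersectLength(a: Seq[String], b: Seq[String]): Boolean =
    a.intersect(b).length == b.length

  // Set-based test: ignores both order and duplicates.
  def bySet(a: Seq[String], b: Seq[String]): Boolean =
    b.toSet.subsetOf(a.toSet)

  def main(args: Array[String]): Unit = {
    val a = Seq("Animation", "Drama", "War")
    val b = Seq("War", "Drama")
    println(bySameElements(a, b))     // false: the intersection is ordered Drama, War
    println(byIntersectLength(a, b))  // true
    println(bySet(a, b))              // true

    // Duplicates in b: the length check treats the arrays as multisets.
    println(byIntersectLength(Seq("Drama"), Seq("Drama", "Drama")))  // false
    println(bySet(Seq("Drama"), Seq("Drama", "Drama")))              // true
  }
}
```

Because intersect keeps the order of its left operand, the sameElements version would reject row 2 even with the UDF types fixed, while the length comparison does not depend on order. If repeated genres in genreList2 should not count against the match, the set-based variant is the one to use.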
    

    Data

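    // toDF needs the session implicits; assumes a SparkSession named spark, as in spark-shell.
    import spark.implicits._
    val df = List((1, Array("Adventure", "Comedy"), Array("Adventure")),
                  (2, Array("Animation", "Drama", "War"), Array("War", "Drama")),
                  (3, Array("Adventure", "Drama"), Array("Drama", "War"))).toDF("movieId1", "genreList1", "genreList2")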
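
As a possible alternative to the UDF, the same flag can be sketched with Spark's built-in array functions, assuming Spark 2.4 or later where array_except is available. This is a different technique from the one above: it treats the arrays as sets, so duplicates in genreList2 are ignored. The object name, the helper names withSubsetFlag and sample, and the local master setting are assumptions made for this example.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array_except, col, size, when}

// Hypothetical object and helper names; a sketch, not the answer's original method.
object SubsetFlagBuiltin {
  // Flag is 1 when no element of genreList2 is missing from genreList1.
  def withSubsetFlag(df: DataFrame): DataFrame =
    df.withColumn(
      "Flag",
      when(size(array_except(col("genreList2"), col("genreList1"))) === 0, 1).otherwise(0)
    )

  // Rebuilds the sample data from the question.
  def sample(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq(
      (1, Seq("Adventure", "Comedy"), Seq("Adventure")),
      (2, Seq("Animation", "Drama", "War"), Seq("War", "Drama")),
      (3, Seq("Adventure", "Drama"), Seq("Drama", "War"))
    ).toDF("movieId1", "genreList1", "genreList2")
  }

  def main(args: Array[String]): Unit = {
    // Local session for experimentation only.
    val spark = SparkSession.builder().master("local[*]").appName("subset-flag").getOrCreate()
    withSubsetFlag(sample(spark)).show(truncate = false)
    spark.stop()
  }
}
```

Staying with built-in functions avoids the WrappedArray typing issue entirely and lets Spark optimize the expression, at the cost of requiring a newer Spark version and using set rather than multiset semantics.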