Tags: scala, apache-spark, dataframe, user-defined-functions

Get distinct elements from rows of type ArrayType in Spark dataframe column


I have a dataframe with the following schema:

    root
     |-- e: array (nullable = true)
     |    |-- element: string (containsNull = true)

For example, create a dataframe:

val df = Seq(Seq("73","73"), null, null, null, Seq("51"), null, null, null, Seq("52", "53", "53", "73", "84"), Seq("73", "72", "51", "73")).toDF("e")

df.show()

+--------------------+
|                   e|
+--------------------+
|            [73, 73]|
|                null|
|                null|
|                null|
|                [51]|
|                null|
|                null|
|                null|
|[52, 53, 53, 73, 84]|
|    [73, 72, 51, 73]|
+--------------------+

I'd like the output to be:

+--------------------+
|                   e|
+--------------------+
|                [73]|
|                null|
|                null|
|                null|
|                [51]|
|                null|
|                null|
|                null|
|    [52, 53, 73, 84]|
|        [73, 72, 51]|
+--------------------+

I am trying the following UDF:

def distinct(arr: TraversableOnce[String]) = arr.toList.distinct
val distinctUDF = udf(distinct(_: Traversable[String]))

But it only works when the rows aren't null, i.e.

df.filter($"e".isNotNull).select(distinctUDF($"e")) 

gives me

+----------------+
|          UDF(e)|
+----------------+
|            [73]|
|            [51]|
|[52, 53, 73, 84]|
|    [73, 72, 51]|
+----------------+

but

df.select(distinctUDF($"e")) 

fails. How do I make the UDF handle null in this case? Alternatively, if there's a simpler way of getting the unique values, I'd like to try that.


Solution

  • You can use when().otherwise() to apply your UDF only when the column value is not null. Here, .otherwise(null) can be dropped entirely, since when() defaults to null whenever no otherwise clause is specified.

    import org.apache.spark.sql.functions.{udf, when}

    val distinctUDF = udf((s: Seq[String]) => s.distinct)

    df.select(when($"e".isNotNull, distinctUDF($"e")).as("e"))
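
  • A variation (not from the original answer) is to handle the null inside the UDF itself by wrapping the input in an Option, so no when() check is needed at the call site. A minimal sketch:

    // Hypothetical variant: Option(s) turns a null array into None,
    // which Spark writes back to the column as null.
    val distinctOrNullUDF = udf((s: Seq[String]) => Option(s).map(_.distinct))

    df.select(distinctOrNullUDF($"e").as("e"))

  • If you are on Spark 2.4 or later, the built-in array_distinct function is the simplest route and needs no UDF at all; like most built-in functions, it returns null for null input rows:

    // Spark 2.4+ only: deduplicate the array with the built-in function.
    import org.apache.spark.sql.functions.array_distinct

    df.select(array_distinct($"e").as("e"))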