Tags: apache-spark, apache-spark-mllib, apache-spark-ml

Spark struct represented by OneHotEncoder


I have a DataFrame with two columns:

+---+-------+
| id|  fruit|
+---+-------+
|  0|  apple|
|  1| banana|
|  2|coconut|
|  1| banana|
|  2|coconut|
+---+-------+

I also have a universal list of all the items:

fruitList: Seq[String] = WrappedArray(apple, coconut, banana)

Now I want to create a new column in the DataFrame containing an array of 1s and 0s, where 1 means the item is present in that row and 0 means it is not.

Desired Output

    +---+---------+
    | id|fruitlist|
    +---+---------+
    |  0|  [1,0,0]|
    |  1|  [0,1,0]|
    |  2|  [0,0,1]|
    |  1|  [0,1,0]|
    |  2|  [0,0,1]|
    +---+---------+

This is what I tried:

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.functions._

val df = spark.createDataFrame(Seq(
  (0, "apple"),
  (1, "banana"),
  (2, "coconut"),
  (1, "banana"),
  (2, "coconut")
)).toDF("id", "fruit")

df.show

// collect the distinct fruits into a driver-side sequence
val fruitList = df.select(collect_set("fruit")).first().getAs[Seq[String]](0)
print(fruitList)
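
One thing to watch: collect_set makes no ordering guarantee, so the position of each fruit in the encoding can change between runs. A minimal sketch of a deterministic alternative, sorting the distinct values instead:

// sort the distinct fruits so positions are stable across runs
val fruitList = df.select("fruit").distinct()
  .orderBy("fruit")
  .collect()
  .map(_.getString(0))
  .toSeq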

I tried to solve this with OneHotEncoder, but after converting the sparse output to a dense vector the result was something like this, which is not what I need:

+---+-------+----------+-------------+---------+
| id|  fruit|fruitIndex|     fruitVec|       vd|
+---+-------+----------+-------------+---------+
|  0|  apple|       2.0|    (2,[],[])|[0.0,0.0]|
|  1| banana|       1.0|(2,[1],[1.0])|[0.0,1.0]|
|  2|coconut|       0.0|(2,[0],[1.0])|[1.0,0.0]|
|  1| banana|       1.0|(2,[1],[1.0])|[0.0,1.0]|
|  2|coconut|       0.0|(2,[0],[1.0])|[1.0,0.0]|
+---+-------+----------+-------------+---------+
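
For reference, a minimal sketch of the kind of pipeline that produces the table above, assuming Spark 2.x, where OneHotEncoder is a plain Transformer (in Spark 3 it is an Estimator and needs a fit step); the toDense helper reproducing the vd column is an assumption:

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions.{col, udf}

// map each fruit to a numeric index (order depends on label frequency)
val indexed = new StringIndexer()
  .setInputCol("fruit")
  .setOutputCol("fruitIndex")
  .fit(df)
  .transform(df)

// one-hot encode the index; the default dropLast = true drops the last
// category, which is why the vectors above have length 2 instead of 3
val encoded = new OneHotEncoder()
  .setInputCol("fruitIndex")
  .setOutputCol("fruitVec")
  .transform(indexed)

// assumed helper: densify the sparse vectors into the vd column
val toDense = udf((v: Vector) => v.toDense)
encoded.withColumn("vd", toDense(col("fruitVec"))).show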

Solution

  • If you have a collection such as

    val fruitList: Seq[String] = Array("apple", "coconut", "banana")
    

    Then you can do it using either built-in functions or a udf function.

    built-in functions (array, when and lit)

    import org.apache.spark.sql.functions._
    // one when/otherwise column per fruit, assembled into a single array column
    df.withColumn("fruitList", array(fruitList.map(x => when(lit(x) === col("fruit"), 1).otherwise(0)): _*)).show(false)
    

    udf function

    import org.apache.spark.sql.functions._
    // compare the row's fruit against every item in fruitList, emitting 1 on a match
    val containedUdf = udf((fruit: String) => fruitList.map(x => if (x == fruit) 1 else 0))

    df.withColumn("fruitList", containedUdf(col("fruit"))).show(false)
    

    which should give you (note that the positions follow the order of fruitList, i.e. apple, coconut, banana, which is why banana maps to [0, 0, 1] here):

    +---+-------+---------+
    |id |fruit  |fruitList|
    +---+-------+---------+
    |0  |apple  |[1, 0, 0]|
    |1  |banana |[0, 0, 1]|
    |2  |coconut|[0, 1, 0]|
    |1  |banana |[0, 0, 1]|
    |2  |coconut|[0, 1, 0]|
    +---+-------+---------+
    

    udf functions are easy to understand and straightforward when dealing with primitive datatypes, but they should be avoided whenever optimized, fast built-in functions can do the same task: built-in expressions are visible to Catalyst's optimizer, while a udf is a black box to it.
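
    A quick way to see the difference is to compare the physical plans of the two versions; the comments below describe what typically shows up, and details vary by Spark version:

    // the built-in version inlines the CASE WHEN expressions into the plan,
    // where Catalyst can optimize them
    df.withColumn("fruitList", array(fruitList.map(x => when(lit(x) === col("fruit"), 1).otherwise(0)): _*)).explain()

    // the udf version appears as an opaque UDF call the optimizer cannot inspect
    df.withColumn("fruitList", containedUdf(col("fruit"))).explain()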

    I hope the answer is helpful.