I have a DataFrame with two columns:
+---+-------+
| id| fruit|
+---+-------+
| 0| apple|
| 1| banana|
| 2|coconut|
| 1| banana|
| 2|coconut|
+---+-------+
I also have a universal list containing all the items:
fruitList: Seq[String] = WrappedArray(apple, coconut, banana)
Now I want to create a new column in the DataFrame containing an array of 1s and 0s, where 1 means the item is present in that row and 0 means it is not.
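To make the requirement concrete, here is a minimal plain-Scala sketch (no Spark) of the per-row encoding. The helper name `encode` is hypothetical. Position `i` of the result is 1 exactly when `universe(i)` equals the row's fruit, so the position of each 1 follows the order of the universal list. With the list in the order shown above (apple, coconut, banana), banana maps to `[0,0,1]`; the desired output's `[0,1,0]` for banana assumes the list is ordered apple, banana, coconut.

```scala
// Hypothetical helper: indicator vector for one value against a fixed universe.
// Position i is 1 exactly when universe(i) equals the value, otherwise 0.
def encode(universe: Seq[String], item: String): Seq[Int] =
  universe.map(x => if (x == item) 1 else 0)

// The universal list, in the order shown above
val fruitList: Seq[String] = Seq("apple", "coconut", "banana")

// The fruit column of the sample data
val fruits: Seq[String] = Seq("apple", "banana", "coconut", "banana", "coconut")

// Print each fruit with its indicator vector
fruits.foreach(f => println(s"$f -> ${encode(fruitList, f).mkString("[", ",", "]")}"))
```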
Desired Output
+---+-----------+
| id| fruitlist|
+---+-----------+
| 0| [1,0,0] |
| 1| [0,1,0] |
| 2|[0,0,1] |
| 1| [0,1,0] |
| 2|[0,0,1] |
+---+-----------+
Here is what I tried:
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
val df = spark.createDataFrame(Seq(
(0, "apple"),
(1, "banana"),
(2, "coconut"),
(1, "banana"),
(2, "coconut")
)).toDF("id", "fruit")
df.show
import org.apache.spark.sql.functions._
// collect_set gathers the distinct fruits; the order of the collected values is not guaranteed
val fruitList = df.select(collect_set("fruit")).first().getAs[Seq[String]](0)
print(fruitList)
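Because Spark documents `collect_set` as non-deterministic in the order of its results, the positions in the indicator array could change between runs. One option, sketched here as an assumption rather than a requirement, is to sort the collected list on the driver, e.g. `df.select(collect_set("fruit")).first().getAs[Seq[String]](0).sorted`. The plain-Scala snippet below uses a hard-coded stand-in for the collected value to show the effect: sorting yields apple, banana, coconut, which is the order the desired output above assumes.

```scala
// Stand-in for the value returned by collect_set; the real order may differ between runs
val collected: Seq[String] = Seq("apple", "coconut", "banana")

// Sorting fixes the position of every fruit, independent of collect_set's order
val fruitList: Seq[String] = collected.sorted

println(fruitList)
println(fruitList.indexOf("banana"))
```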
I tried to solve this with OneHotEncoder, but after converting the result to a dense vector I got something like the following, which is not what I need:
+---+-------+----------+-------------+---------+
| id| fruit|fruitIndex| fruitVec| vd|
+---+-------+----------+-------------+---------+
| 0| apple| 2.0| (2,[],[])|[0.0,0.0]|
| 1| banana| 1.0|(2,[1],[1.0])|[0.0,1.0]|
| 2|coconut| 0.0|(2,[0],[1.0])|[1.0,0.0]|
| 1| banana| 1.0|(2,[1],[1.0])|[0.0,1.0]|
| 2|coconut| 0.0|(2,[0],[1.0])|[1.0,0.0]|
+---+-------+----------+-------------+---------+
If you have a collection such as
val fruitList: Seq[String] = Seq("apple", "coconut", "banana")
then you can do it either with built-in functions or with a UDF.
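import org.apache.spark.sql.functions._
// Built-in functions: one when(...).otherwise(...) expression per fruit, combined into a single array column
df.withColumn("fruitList", array(fruitList.map(x => when(lit(x) === col("fruit"), 1).otherwise(0)): _*)).show(false)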
import org.apache.spark.sql.functions._
// UDF alternative: the same mapping done in Scala code for each row
val containedUdf = udf((fruit: String) => fruitList.map(x => if (x == fruit) 1 else 0))
df.withColumn("fruitList", containedUdf(col("fruit"))).show(false)
Either approach should give you the following (the positions in each array follow the order of fruitList):
+---+-------+---------+
|id |fruit |fruitList|
+---+-------+---------+
|0 |apple |[1, 0, 0]|
|1 |banana |[0, 0, 1]|
|2 |coconut|[0, 1, 0]|
|1 |banana |[0, 0, 1]|
|2 |coconut|[0, 1, 0]|
+---+-------+---------+
UDFs are easy to understand and straightforward when dealing with primitive data types, but they should be avoided when optimized built-in functions can do the same task, because Spark's Catalyst optimizer treats a UDF as a black box and cannot optimize it.
I hope the answer is helpful.