
spark OneHotEncoder - how to exclude user-defined category?


Consider the following Spark DataFrame:

df.printSchema()

     |-- predictor: double (nullable = true)
     |-- label: double (nullable = true)
     |-- date: string (nullable = true)

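df.show(6)

    +---------+-----+--------+
    |predictor|label|    date|
    +---------+-----+--------+
    |     4.23| 6.33|20160510|
    |     4.77| 7.18|20160510|
    |     4.09| 5.94|20160511|
    |     4.23| 6.33|20160511|
    |     4.77| 7.18|20160512|
    |     4.09| 5.94|20160512|
    +---------+-----+--------+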

My DataFrame holds daily data, and I need to map the date column to a column of binary (one-hot) vectors. This is straightforward with StringIndexer and OneHotEncoder:

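import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// Map each distinct date string to a numeric index
val dateIndexer = new StringIndexer()
  .setInputCol("date")
  .setOutputCol("dateIndex")
  .fit(df)
val indexed = dateIndexer.transform(df)

// Spark 1.x/2.x API: OneHotEncoder is a Transformer here
// (in Spark 3.x it is an Estimator and must be fit before transform)
val encoder = new OneHotEncoder()
  .setInputCol("dateIndex")
  .setOutputCol("date_codeVec")

val encoded = encoder.transform(indexed)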

My problem is that OneHotEncoder drops the last category by default (dropLast = true), whereas I need to drop the category for the first date in my DataFrame (20160510 in the example above), because I want to compute a time trend relative to that first date.
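To make the requirement concrete, here is a minimal plain-Scala sketch (no Spark involved; the object ReferenceCoding and its dummyCode helper are hypothetical names used only for illustration) of dummy coding with an explicitly chosen reference level: the reference date maps to the all-zero vector and every other date gets its own column. It assumes the dates are yyyyMMdd strings, so the lexicographic minimum is also the earliest date.

```scala
// Hypothetical helper (not a Spark API): dummy-code a categorical column
// using an explicitly chosen reference level, which maps to the all-zero vector.
object ReferenceCoding {
  /** Returns the non-reference levels (column order) and one 0/1 vector per input value. */
  def dummyCode(values: Seq[String], reference: String): (Seq[String], Seq[Vector[Double]]) = {
    // Non-reference levels, sorted so the column order is deterministic
    val levels = values.distinct.filterNot(_ == reference).sorted
    val position = levels.zipWithIndex.toMap
    val rows = values.map { v =>
      // Reference rows have no entry in `position`, so they stay all zeros
      Vector.tabulate(levels.size)(i => if (position.get(v).contains(i)) 1.0 else 0.0)
    }
    (levels, rows)
  }

  def main(args: Array[String]): Unit = {
    val dates = Seq("20160510", "20160510", "20160511", "20160511", "20160512", "20160512")
    // yyyyMMdd strings sort chronologically, so min is the first date
    val (levels, rows) = dummyCode(dates, reference = dates.min)
    println(levels.mkString(","))
    rows.foreach(r => println(r.mkString("[", ",", "]")))
  }
}
```

Each trend coefficient estimated on these columns is then a difference relative to the first date, which is the interpretation the question is after.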

How can I achieve this for the above example (note that I have more than 100 dates in my dataframe)?


Solution

  • Set dropLast to false so that every date, including the last one, gets its own column:

    val encoder = new OneHotEncoder()
      .setInputCol("dateIndex")
      .setOutputCol("date_codeVec")
      .setDropLast(false)
    
    val encoded = encoder.transform(indexed)
    

    and then drop the reference level yourself with VectorSlicer, keeping every date except the earliest one:

    import org.apache.spark.ml.feature.VectorSlicer

    // setNames matches the vector's attribute names, which OneHotEncoder copies
    // from the StringIndexer labels. labels.min is the earliest date only
    // because yyyyMMdd strings sort chronologically.
    val slicer = new VectorSlicer()
      .setInputCol("date_codeVec")
      .setOutputCol("date_codeVec_selected")
      .setNames(dateIndexer.labels.diff(Seq(dateIndexer.labels.min)))

    slicer.transform(encoded).show()

    +---------+-----+--------+---------+-------------+---------------------+
    |predictor|label|    date|dateIndex| date_codeVec|date_codeVec_selected|
    +---------+-----+--------+---------+-------------+---------------------+
    |     4.23| 6.33|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
    |     4.77| 7.18|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
    |     4.09| 5.94|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
    |     4.23| 6.33|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
    |     4.77| 7.18|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
    |     4.09| 5.94|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
    +---------+-----+--------+---------+-------------+---------------------+
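A different route that skips the slicing step, sketched here under the assumption of Spark 3.x (StringIndexer.setStringOrderType exists since Spark 2.3; in Spark 3.x OneHotEncoder is an Estimator and must be fit first): index the dates in descending alphabetical order, so the earliest date receives the largest index, which is exactly the category the default dropLast = true removes. Like labels.min above, this relies on the dates being yyyyMMdd strings. The object DropFirstDate and its encodeDates method are hypothetical names.

```scala
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.{DataFrame, SparkSession}

object DropFirstDate {
  // Index dates so the earliest gets the last index, then one-hot encode
  // with the default dropLast = true so that date becomes the all-zero vector.
  def encodeDates(df: DataFrame): DataFrame = {
    // alphabetDesc: "20160512" -> 0, "20160511" -> 1, "20160510" -> 2
    val indexed = new StringIndexer()
      .setInputCol("date")
      .setOutputCol("dateIndex")
      .setStringOrderType("alphabetDesc")
      .fit(df)
      .transform(df)

    // Spark 3.x: fit the encoder (it records the number of categories), then transform
    new OneHotEncoder()
      .setInputCol("dateIndex")
      .setOutputCol("date_codeVec")
      .fit(indexed)
      .transform(indexed)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("drop-first-date").getOrCreate()
    import spark.implicits._

    // Sample rows from the question
    val df = Seq(
      (4.23, 6.33, "20160510"), (4.77, 7.18, "20160510"),
      (4.09, 5.94, "20160511"), (4.23, 6.33, "20160511"),
      (4.77, 7.18, "20160512"), (4.09, 5.94, "20160512")
    ).toDF("predictor", "label", "date")

    encodeDates(df).show()
    spark.stop()
  }
}
```

The trade-off is that column order now runs from the latest date to the second-earliest; if the column order matters downstream, the VectorSlicer approach keeps more control over it.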