
Spark OneHotEncoder - how to exclude a user-defined category?

Consider the following Spark DataFrame:


     |-- predictor: double (nullable = true)
     |-- label: double (nullable = true)
     |-- date: string (nullable = true)

    predictor      label              date    
    4.23           6.33               20160510
    4.77           7.18               20160510
    4.09           5.94               20160511
    4.23           6.33               20160511
    4.77           7.18               20160512
    4.09           5.94               20160512

My dataframe holds daily data. I need to map the date column to a column of binary (one-hot) vectors, which is simple with StringIndexer and OneHotEncoder:

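    import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

    val dateIndexer = new StringIndexer().setInputCol("date").setOutputCol("dateIndex")
    val indexed = dateIndexer.fit(df).transform(df) // StringIndexer is an Estimator: fit it first

    val encoder = new OneHotEncoder().setInputCol("dateIndex").setOutputCol("date_codeVec")
    // Spark 1.x/2.x: OneHotEncoder is a Transformer (in Spark 3.x, call .fit(indexed) first)
    val encoded = encoder.transform(indexed)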

The problem is that OneHotEncoder drops the last category (the one with the highest index) by default, whereas I need to drop the category for the first date in my dataframe (20160510 in the example above), because I need to compute a time trend relative to that first date.
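Since dropLast always removes whichever category ends up with the highest index, one option is to make the earliest date sort last. The dates are yyyyMMdd strings, so descending string order is reverse chronological order. The sketch below is a plain-Scala model (not Spark's implementation) of StringIndexer with descending alphabetical label order followed by OneHotEncoder with dropLast = true; DropFirstDateSketch and its functions are hypothetical names. The assumed Spark counterpart is setStringOrderType("alphabetDesc") on StringIndexer, which to my knowledge is available from Spark 2.3 onward, so check your version.

```scala
// Plain-Scala model (NOT Spark code) of StringIndexer + OneHotEncoder semantics,
// used to check that descending label order turns the earliest date into the
// category that dropLast removes. All names here are hypothetical.
object DropFirstDateSketch {

  // Models StringIndexer with stringOrderType = "alphabetDesc" (assumed Spark 2.3+):
  // distinct labels sorted in descending string order; position = category index.
  def labelsAlphabetDesc(dates: Seq[String]): Vector[String] =
    dates.distinct.sorted(Ordering[String].reverse).toVector

  // Models one OneHotEncoder output as (vector size, active indices).
  // With dropLast = true there are numCategories - 1 slots, so the last index is all zeros.
  def oneHot(index: Int, numCategories: Int, dropLast: Boolean): (Int, Seq[Int]) = {
    val size = if (dropLast) numCategories - 1 else numCategories
    val active = if (index < size) Seq(index) else Seq.empty[Int]
    (size, active)
  }

  // Encodes every date; the earliest yyyyMMdd date gets the highest index.
  def encode(dates: Seq[String], dropLast: Boolean = true): Seq[(String, (Int, Seq[Int]))] = {
    val labels = labelsAlphabetDesc(dates)
    dates.map(d => d -> oneHot(labels.indexOf(d), labels.size, dropLast))
  }

  def main(args: Array[String]): Unit = {
    val dates = Seq("20160510", "20160510", "20160511", "20160511", "20160512", "20160512")
    encode(dates).distinct.foreach(println)
  }
}
```

With the example dates, 20160510 encodes to an empty vector of size 2, so it becomes the reference category. The remaining slots are in reverse chronological order, which does not change a regression on these dummies but matters when reading the coefficients.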

How can I achieve this for the above example (note that I have more than 100 dates in my dataframe)?


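  • You can try setting dropLast to false, so that every category is kept:

    val encoder = new OneHotEncoder().setInputCol("dateIndex").setOutputCol("date_codeVec")
      .setDropLast(false)
    val encoded = encoder.transform(indexed)

    and then dropping the unwanted level manually with VectorSlicer. Which index belongs to 20160510 depends on StringIndexer's ordering; in this run it received index 0, so the slicer keeps indices 1 and 2:

    import org.apache.spark.ml.feature.VectorSlicer

    val slicer = new VectorSlicer().setInputCol("date_codeVec").setOutputCol("date_codeVec_selected")
      .setIndices(Array(1, 2))
    slicer.transform(encoded).show()

    +---------+-----+--------+---------+-------------+---------------------+
    |predictor|label|    date|dateIndex| date_codeVec|date_codeVec_selected|
    +---------+-----+--------+---------+-------------+---------------------+
    |     4.23| 6.33|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
    |     4.77| 7.18|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
    |     4.09| 5.94|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
    |     4.23| 6.33|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
    |     4.77| 7.18|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
    |     4.09| 5.94|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
    +---------+-----+--------+---------+-------------+---------------------+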
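With more than 100 dates, hard-coding the slicer indices is fragile, because the index StringIndexer assigns to the first date depends on its label ordering. A minimal sketch, assuming the labels come from the fitted StringIndexerModel (its labels array, where position i is the category with index i; Spark 3.x also exposes labelsArray): compute every index except the earliest date's and pass the result to VectorSlicer.setIndices. ReferenceSlice, earliest and keepIndices are hypothetical names, and the Spark wiring appears only as comments.

```scala
// Plain-Scala helper (hypothetical names) for the VectorSlicer approach:
// given StringIndexerModel-style labels, return every vector position
// except the one that belongs to the reference (earliest) date.
object ReferenceSlice {

  // yyyyMMdd strings sort chronologically as plain strings, so min = earliest date.
  def earliest(labels: Array[String]): String = labels.min

  // All indices 0 until labels.length except the reference one, ascending.
  // VectorSlicer emits selected features in the order the indices are given.
  def keepIndices(labels: Array[String], reference: String): Array[Int] = {
    val ref = labels.indexOf(reference)
    require(ref >= 0, s"reference label $reference not found")
    labels.indices.filterNot(_ == ref).toArray
  }

  def main(args: Array[String]): Unit = {
    // Label order read off the dateIndex column in the table above.
    val labels = Array("20160510", "20160512", "20160511")
    val keep = keepIndices(labels, earliest(labels))
    println(keep.mkString("[", ", ", "]"))
    // Assumed Spark wiring (not executed here):
    //   val model = new StringIndexer().setInputCol("date").setOutputCol("dateIndex").fit(df)
    //   val slicer = new VectorSlicer().setInputCol("date_codeVec")
    //     .setOutputCol("date_codeVec_selected")
    //     .setIndices(keepIndices(model.labels, earliest(model.labels)))
  }
}
```

With the label order from the table, main prints [1, 2], the same indices used for the slicer above, but the helper keeps working however StringIndexer happens to order the dates.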