Tags: apache-spark, pipeline, apache-spark-ml

Dealing with dynamic columns with VectorAssembler


When using Spark's VectorAssembler, the columns to be assembled need to be defined up front.

However, if the VectorAssembler is used in a pipeline where previous steps modify the columns of the data frame, how can I specify the columns without hard-coding all the values manually?

As df.columns will not contain the right values when the VectorAssembler's constructor is called, I currently see no other way to handle this than splitting the pipeline, which is bad as well, because CrossValidator will then no longer work properly.

val vectorAssembler = new VectorAssembler()
    .setInputCols(df.columns
      .filter(!_.contains("target"))
      .filter(!_.contains("idNumber")))
    .setOutputCol("features")
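To illustrate the problem: the exclusion filter above runs on whatever df.columns returns at construction time, not on the columns present when the pipeline actually executes. A minimal plain-Scala sketch, with a hypothetical column list:

```scala
// Hypothetical columns of df at the time the VectorAssembler is constructed.
val columnsAtConstruction = Array("foo", "id", "baz", "target", "idNumber")

// Same exclusion filter as in the snippet above.
val inputCols = columnsAtConstruction
  .filter(!_.contains("target"))
  .filter(!_.contains("idNumber"))

// inputCols is now frozen as ("foo", "id", "baz"); columns added by
// later pipeline stages (e.g. a derived "isA") are never picked up.
```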

Edit

An initial df of

+---+---+----+
|foo| id| baz|
+---+---+----+
|  0|  1|   A|
|  1|  2|   A|
|  0|  3|null|
|  1|  4|   C|
+---+---+----+

will be transformed as follows. Null values in the original columns are imputed with the most frequent value, and some features are derived, e.g. isA as outlined here: it is 1 if baz is A, 0 otherwise, and N if baz was originally null.

+---+---+---+---+
|foo| id|baz|isA|
+---+---+---+---+
|  0|  1|  A|  1|
|  1|  2|  A|  1|
|  0|  3|  A|  N|
|  1|  4|  C|  0|
+---+---+---+---+

Later on in the pipeline, a StringIndexer is used to make the data fit for ML / the VectorAssembler.

isA is not present in the original df, but it is not the "only" new column either: all the columns in this frame except foo and the id column should be transformed by the VectorAssembler.

I hope it is clearer now.


Solution

  • I created a custom VectorAssembler (a 1:1 copy of the original) and then changed it to include all columns except those passed in to be excluded.

    Edit

    To make it a bit clearer

    def setInputColsExcept(value: Array[String]): this.type = set(inputCols, value)
    

    specifies which columns should be excluded. And then

    val remainingColumns = dataset.columns.filter(!$(inputCols).contains(_))
    

    in the transform method filters for the desired columns.
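Putting both pieces together, the selection logic can be sketched in plain Scala (the column names are hypothetical, matching the example frames above; `excluded` stands in for the `$(inputCols)` param set via setInputColsExcept):

```scala
// Columns the custom assembler should skip, as set via setInputColsExcept.
val excluded = Array("foo", "id")

// Columns of the dataset as seen inside transform(), i.e. after the
// imputation and feature-derivation stages have already run.
val datasetColumns = Array("foo", "id", "baz", "isA")

// Equivalent of: dataset.columns.filter(!$(inputCols).contains(_))
val remainingColumns = datasetColumns.filter(!excluded.contains(_))
// remainingColumns keeps "baz" and "isA", including columns that did
// not exist when the pipeline was constructed.
```

Because the filtering happens inside transform(), it sees the dataset's actual schema at execution time, which is what makes this work with dynamically added columns.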