scala apache-spark apache-spark-sql apache-spark-ml

Spark custom preprocessing estimator


I want to write a custom Estimator for Spark's Pipelines. It should perform data cleaning tasks: some rows will be dropped, some columns dropped, some columns added, and some values in existing columns replaced. It should also store the mean or min of some numeric columns to use as a NaN replacement.

However,

override def transformSchema(schema: StructType): StructType = {
   schema.add(StructField("foo", IntegerType))
}

only seems to support adding fields. How am I supposed to handle columns that get removed?


Solution

  • You are correct that the StructType API only supports adding fields. However, that does not mean you cannot remove fields, too!

    StructType has a value member fields, which gives you an Array[StructField]. You can .filter() this array however you see fit (by name, dataType, or something more complicated), keeping only the columns you want.

    Once you've done your filtering, you have two options:

    1. append a StructField for each new column to the filtered fields array (for example with :+, which returns a new array) and construct a StructType from the result.
    2. construct a StructType from the filtered fields array and add new columns using .add(...).
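A minimal sketch of both options, assuming the cleaning step drops a column called `debug_info` and adds a boolean column called `score_was_missing`. Both names, and the `CleaningSchema` object itself, are hypothetical and only there for illustration; the calls used (`StructType.fields`, `StructType.add`, building a `StructType` from an `Array[StructField]`) are standard Spark SQL types API.

```scala
import org.apache.spark.sql.types._

// Schema bookkeeping for a hypothetical cleaning step.
// The column names below are illustrative assumptions, not part of any Spark API.
object CleaningSchema {
  val droppedColumns: Set[String] = Set("debug_info")
  val addedField: StructField =
    StructField("score_was_missing", BooleanType, nullable = false)

  // Keep every field except the dropped ones. Filtering on f.dataType
  // (e.g. f.dataType.isInstanceOf[NumericType]) works the same way.
  def keptFields(schema: StructType): Array[StructField] =
    schema.fields.filterNot(f => droppedColumns.contains(f.name))

  // Option 1: append the new field to the filtered array, then build the StructType.
  def option1(schema: StructType): StructType =
    StructType(keptFields(schema) :+ addedField)

  // Option 2: build the StructType from the filtered array, then call .add(...).
  def option2(schema: StructType): StructType =
    StructType(keptFields(schema)).add(addedField)
}
```

Both functions produce the same schema, so the choice is a matter of taste. Inside the Estimator and its Model you would then delegate, e.g. `override def transformSchema(schema: StructType): StructType = CleaningSchema.option2(schema)`. Dropped rows and replaced values do not change the schema, so they need no handling in `transformSchema`.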