Tags: apache-spark, apache-spark-sql, pipeline, apache-spark-mllib, apache-spark-ml

Spark custom estimator including persistence


I want to develop a custom estimator for Spark that also supports the persistence features of the great Pipeline API. But, as the question "How to Roll a Custom Estimator in PySpark mllib" puts it, there is not much documentation out there (yet).

I have some data cleansing code written in Spark and would like to wrap it in a custom estimator. It includes some NA substitutions, column deletions, filtering and basic feature generation (e.g. birthdate to age).

  • transformSchema will use the dataset's case class: ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]
  • fit will only compute the fitted values, e.g. the mean age used as a substitute for missing values

What is still pretty unclear to me:

  • transform in the custom pipeline model will be used to apply the "fitted" Estimator to new data. Is this correct? If so, how should I transfer the fitted values, e.g. the mean age from above, into the model?

  • How should I handle persistence? I found a generic loadImpl method within private Spark components, but I am unsure how to pass my own parameters, e.g. the mean age, into the MLReader / MLWriter used for serialization.

It would be great if you could help me with a custom estimator, especially with the persistence part.


Solution

  • First of all, I believe you're mixing up two different things a bit:

    • Estimators, which represent stages that can be fitted. An Estimator's fit method takes a Dataset and returns a Transformer (model).
    • Transformers, which represent stages that can transform data.

    When you fit a Pipeline, it fits all Estimators and returns a PipelineModel. A PipelineModel transforms data by sequentially calling transform on all Transformers in the model.
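    To illustrate the Estimator / Transformer split with built-in stages only, here is a minimal sketch assuming Spark 2.2 or later (where org.apache.spark.ml.feature.Imputer exists). PipelineFitExample, buildPipeline and the column names are hypothetical. Imputer is an Estimator that learns a column mean, while SQLTransformer is a plain Transformer with nothing to learn. For the mean-age substitution specifically, Imputer may already be enough without a custom estimator.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{Imputer, SQLTransformer}

// Hypothetical helper object, not part of Spark.
object PipelineFitExample {
  // Builds an unfitted Pipeline: one Estimator followed by one plain Transformer.
  def buildPipeline(): Pipeline = {
    // Imputer is an Estimator: fit() learns the mean of "age" (nulls are ignored).
    val imputer = new Imputer()
      .setInputCols(Array("age"))
      .setOutputCols(Array("age_filled"))
      .setStrategy("mean")

    // SQLTransformer is a Transformer: it has no fitted state.
    val selectFilled = new SQLTransformer()
      .setStatement("SELECT id, age_filled FROM __THIS__")

    new Pipeline().setStages(Array(imputer, selectFilled))
  }
}
```

    Calling buildPipeline().fit(df) returns a PipelineModel whose stages are an ImputerModel (the fitted Imputer) followed by the unchanged SQLTransformer; calling transform on that model runs each stage's transform in order.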

    how should I transfer the fitted values

    There is no single answer to this question. In general, you have two options:

    • Pass parameters of the fitted model as constructor arguments of the Transformer.
    • Make parameters of the fitted model Params of the Transformer.

    The first approach is typically used by the built-in Transformers, but the second one should work in some simple cases.
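    To make the first option concrete, here is a minimal sketch assuming the Spark 2.x/3.x Scala ML API and a DoubleType input column. MeanImputer, MeanImputerModel and the inputCol param are hypothetical names, not Spark classes. fit computes the mean and passes it to the model's constructor; the model's transform then uses that value to fill nulls in new data.

```scala
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{avg, coalesce, col, lit}
import org.apache.spark.sql.types.StructType

// Hypothetical example classes, not part of Spark.
// Params shared by the estimator and its model.
trait MeanImputerParams extends Params {
  final val inputCol: Param[String] =
    new Param[String](this, "inputCol", "numeric column whose nulls are replaced")
  def getInputCol: String = $(inputCol)
}

// Estimator: fit computes the mean and hands it to the model's constructor.
class MeanImputer(override val uid: String)
    extends Estimator[MeanImputerModel] with MeanImputerParams {

  def this() = this(Identifiable.randomUID("meanImputer"))
  def setInputCol(value: String): this.type = set(inputCol, value)

  override def fit(dataset: Dataset[_]): MeanImputerModel = {
    // avg ignores nulls; this sketch assumes at least one non-null value.
    val mean = dataset.select(avg(col($(inputCol)))).head().getDouble(0)
    // copyValues transfers inputCol (and any other set params) to the model.
    copyValues(new MeanImputerModel(uid, mean).setParent(this))
  }

  // Assumes a DoubleType input column, so the schema is unchanged.
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): MeanImputer = defaultCopy(extra)
}

// Model: the fitted mean is a plain constructor argument (option 1).
class MeanImputerModel(override val uid: String, val mean: Double)
    extends Model[MeanImputerModel] with MeanImputerParams {

  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn($(inputCol), coalesce(col($(inputCol)), lit(mean)))

  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): MeanImputerModel =
    copyValues(new MeanImputerModel(uid, mean), extra).setParent(parent)
}
```

    Usage: new MeanImputer().setInputCol("age").fit(df) returns a MeanImputerModel whose mean field holds the fitted value, and model.transform(newDf) fills nulls with it. Because mean is not a Param here, this model cannot rely on default Params persistence; it would need its own MLWriter / MLReader.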

    how to handle persistence

    • If the Transformer is defined only by its Params, you can mix in DefaultParamsWritable and have its companion object extend DefaultParamsReadable.
    • If you use more complex arguments, you should extend MLWritable and implement an MLWriter that makes sense for your data (plus a matching MLReader exposed through an MLReadable companion object). There are multiple examples in the Spark source that show how to implement data and metadata reading / writing.
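    A minimal sketch of the Params-only route, assuming the Spark 2.x/3.x Scala ML API: the fitted mean is stored in a DoubleParam (the second option above), so DefaultParamsWritable on the classes and DefaultParamsReadable on their companion objects handle saving and loading without a custom MLWriter. ParamImputer, ParamImputerModel, HasImputeCol and the surrogate param are hypothetical names, not Spark classes. The default reader recreates each class through its single-String (uid) constructor, so that constructor must exist.

```scala
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{avg, coalesce, col, lit}
import org.apache.spark.sql.types.StructType

// Hypothetical example classes, not part of Spark.
trait HasImputeCol extends Params {
  final val inputCol: Param[String] =
    new Param[String](this, "inputCol", "numeric column whose nulls are replaced")
  def getInputCol: String = $(inputCol)
}

class ParamImputer(override val uid: String)
    extends Estimator[ParamImputerModel] with HasImputeCol with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("paramImputer"))
  def setInputCol(value: String): this.type = set(inputCol, value)

  override def fit(dataset: Dataset[_]): ParamImputerModel = {
    // Assumes at least one non-null value in the column.
    val mean = dataset.select(avg(col($(inputCol)))).head().getDouble(0)
    copyValues(new ParamImputerModel(uid).setParent(this)).setSurrogate(mean)
  }

  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): ParamImputer = defaultCopy(extra)
}

// Companion object provides read / load via the default params reader.
object ParamImputer extends DefaultParamsReadable[ParamImputer]

// The fitted mean is a Param (option 2), so it is written to the JSON
// metadata together with inputCol by DefaultParamsWritable.
class ParamImputerModel(override val uid: String)
    extends Model[ParamImputerModel] with HasImputeCol with DefaultParamsWritable {

  final val surrogate: DoubleParam =
    new DoubleParam(this, "surrogate", "fitted mean used to replace nulls")
  def getSurrogate: Double = $(surrogate)
  def setSurrogate(value: Double): this.type = set(surrogate, value)

  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn($(inputCol), coalesce(col($(inputCol)), lit($(surrogate))))

  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): ParamImputerModel =
    copyValues(new ParamImputerModel(uid), extra).setParent(parent)
}

object ParamImputerModel extends DefaultParamsReadable[ParamImputerModel]
```

    Usage: model.write.overwrite().save(path) followed by ParamImputerModel.load(path) should restore both inputCol and surrogate. The trade-off is that the fitted value becomes a settable Param; for larger fitted state (vocabularies, matrices) a custom MLWriter that stores data alongside the metadata is the better fit.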

    If you're looking for an easy-to-comprehend example, take a look at the CountVectorizer(Model) where: