scala, apache-spark, apache-spark-mllib, apache-spark-ml

Spark ML - Save OneVsRestModel


I am in the middle of refactoring my code to take advantage of DataFrames, Estimators, and Pipelines. I was originally using MLlib's multiclass LogisticRegressionWithLBFGS on an RDD[LabeledPoint]. I am enjoying learning and using the new API, but I am not sure how to save my new model and apply it to new data.

Currently, the ML implementation of LogisticRegression only supports binary classification, so I am using OneVsRest instead, like so:

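import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

val lr = new LogisticRegression().setFitIntercept(true)
val ovr = new OneVsRest()
ovr.setClassifier(lr)
val ovrModel = ovr.fit(training) // training: DataFrame with "label" and "features" columns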

I would now like to save my OneVsRestModel, but this does not seem to be supported by the API. I have tried:

ovrModel.save("my-ovr") // Cannot resolve symbol save
ovrModel.models.foreach(model => model.save("model-" + model.uid)) // Cannot resolve symbol save

Is there a way to save this model so that I can load it in a new application and make new predictions?


Solution

  • Spark >= 2.0.0

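    OneVsRestModel implements MLWritable (and its companion object implements MLReadable), so the whole model can be saved directly with save and loaded back with OneVsRestModel.load. The method shown below can still be useful for saving the individual models separately.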
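On Spark 2.0 and later, the round trip can go through the standard persistence API: write the fitted OneVsRestModel, then read it back with OneVsRestModel.load in the application that makes predictions. A minimal sketch, assuming a training DataFrame with the default "label" and "features" columns; the names OvrPersistence, trainAndSave and loadAndPredict are hypothetical.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}
import org.apache.spark.sql.DataFrame

object OvrPersistence {
  // Fit one-vs-rest logistic regression and write the whole OneVsRestModel to `path`.
  // overwrite() replaces existing output at `path` instead of failing.
  def trainAndSave(training: DataFrame, path: String): OneVsRestModel = {
    val lr = new LogisticRegression().setFitIntercept(true)
    val ovrModel = new OneVsRest().setClassifier(lr).fit(training)
    ovrModel.write.overwrite().save(path)
    ovrModel
  }

  // In the consuming application: read the model back and score new data.
  // The result has a "prediction" column added by transform.
  def loadAndPredict(path: String, newData: DataFrame): DataFrame =
    OneVsRestModel.load(path).transform(newData)
}
```

Using overwrite() keeps reruns from failing on an existing directory; drop it if accidental overwrites should be an error.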

  • Spark < 2.0.0

    The problem here is that models returns an Array[ClassificationModel[_, _]], not an Array of LogisticRegressionModel (or MLWritable), so the compiler cannot find save on its elements. To make it work, you have to be specific about the types:

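    import org.apache.spark.ml.classification.LogisticRegressionModel
    
    // models is typed as Array[ClassificationModel[_, _]], so narrow each
    // element to LogisticRegressionModel, which provides save
    ovrModel.models.zipWithIndex.foreach { 
      case (model: LogisticRegressionModel, i: Int) => 
        model.save(s"model-${model.uid}-$i")
    }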
    

    or to be more generic:

    import org.apache.spark.ml.util.MLWritable
    
    // Any sub-model that mixes in MLWritable can be saved this way
    ovrModel.models.zipWithIndex.foreach { 
      case (model: MLWritable, i: Int) =>
        model.save(s"model-${model.uid}-$i")
    }
    

    Unfortunately, as of Spark 1.6, OneVsRestModel doesn't implement MLWritable, so it cannot be saved as a single model.

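    Note:

    All models in the OneVsRestModel share the same uid (each one is fitted by the same classifier instance and inherits its uid), hence the explicit index. The index is also useful for identifying each model later: model i is the binary classifier for class label i.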
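Before Spark 2.0 there is no OneVsRestModel reader, so in the consuming application the individually saved binary models have to be loaded and combined by hand. The sketch below re-implements the one-vs-rest decision rule manually (pick the label whose binary model gives the largest margin); it is not a Spark API. It assumes Spark 2.x type names (on 1.6 the vector type is org.apache.spark.mllib.linalg.Vector instead of org.apache.spark.ml.linalg.Vector), binary LogisticRegressionModel sub-models, and a hypothetical directory layout <baseDir>/model-<i> where i is the class label. BinaryModelStore and its methods are illustrative names.

```scala
import org.apache.spark.ml.classification.{LogisticRegressionModel, OneVsRestModel}
import org.apache.spark.ml.linalg.Vector

object BinaryModelStore {
  // Save each binary model under the hypothetical layout "<baseDir>/model-<i>",
  // where i is the class label that model was trained to separate from the rest.
  def save(ovrModel: OneVsRestModel, baseDir: String): Unit =
    ovrModel.models.zipWithIndex.foreach {
      case (model: LogisticRegressionModel, i) =>
        model.write.overwrite().save(s"$baseDir/model-$i")
      case (other, _) =>
        throw new IllegalArgumentException(s"Unsupported model type: ${other.getClass}")
    }

  // Load the models back in label order (index i -> label i).
  def load(baseDir: String, numClasses: Int): Array[LogisticRegressionModel] =
    Array.tabulate(numClasses)(i => LogisticRegressionModel.load(s"$baseDir/model-$i"))

  // Manual one-vs-rest rule: the label whose model has the largest margin
  // (intercept + coefficients . features). The sigmoid is monotonic, so this is
  // also the label with the highest positive-class probability.
  def predict(models: Array[LogisticRegressionModel], features: Vector): Double = {
    val x = features.toArray
    val margins = models.map { m =>
      m.intercept + m.coefficients.toArray.zip(x).map { case (w, v) => w * v }.sum
    }
    margins.indices.maxBy(i => margins(i)).toDouble
  }
}
```

Computing the margin by hand avoids depending on DataFrame plumbing in the consuming application; for large datasets, wrapping predict in a UDF over the "features" column is a natural next step.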