
Spark ML Pipeline API save not working


In version 1.6 the Pipeline API gained a set of features for saving and loading pipeline stages. After training a classifier, I tried to save a stage to disk and load it again later, so that I can reuse it without having to compute the model again.

For some reason, when I save the model, the output directory contains only the metadata directory. When I try to load it again, I get the following exception:

Exception in thread "main" java.lang.UnsupportedOperationException: empty collection
    at org.apache.spark.rdd.RDD$$anonfun$first$1.apply(RDD.scala:1330)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:150)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:111)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:316)
    at org.apache.spark.rdd.RDD.first(RDD.scala:1327)
    at org.apache.spark.ml.util.DefaultParamsReader$.loadMetadata(ReadWrite.scala:284)
    at org.apache.spark.ml.tuning.CrossValidator$SharedReadWrite$.load(CrossValidator.scala:287)
    at org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelReader.load(CrossValidator.scala:393)
    at org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelReader.load(CrossValidator.scala:384)
    at org.apache.spark.ml.util.MLReadable$class.load(ReadWrite.scala:176)
    at org.apache.spark.ml.tuning.CrossValidatorModel$.load(CrossValidator.scala:368)
    at org.apache.spark.ml.tuning.CrossValidatorModel.load(CrossValidator.scala)
    at org.test.categoryminer.spark.SparkTextClassifierModelCache.get(SparkTextClassifierModelCache.java:34)

To save the model I use: crossValidatorModel.save("/tmp/my.model")

To load it I use: CrossValidatorModel.load("/tmp/my.model")

I call save on the CrossValidatorModel object I get when I call fit(dataframe) on the CrossValidator object.

Any pointers as to why it saves only the metadata directory?


Solution

  • This does not answer your question directly, because I have not tested the new feature in 1.6.0 myself.

    Instead, I use a dedicated function that writes the model to a file with plain Java serialization:

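      import java.io.{FileNotFoundException, FileOutputStream, IOException, ObjectOutputStream}
      import org.apache.spark.ml.tuning.CrossValidatorModel

      // Writes the model to a local file using Java serialization.
      def saveCrossValidatorModel(model: CrossValidatorModel, path: String): Unit = {
        try {
          val fileOut = new FileOutputStream(path)
          try {
            val out = new ObjectOutputStream(fileOut)
            out.writeObject(model)
            out.close() // flushes any buffered bytes to the file
          } finally {
            fileOut.close() // runs even if writeObject throws
          }
        } catch {
          case foe: FileNotFoundException =>
            foe.printStackTrace()
          case ioe: IOException =>
            ioe.printStackTrace()
        }
      }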
    

    And you can then read your model in a similar way:

      import java.io.{FileInputStream, ObjectInputStream}
      import org.apache.spark.ml.tuning.CrossValidatorModel

      // Reads a model previously written by saveCrossValidatorModel.
      // I/O errors propagate to the caller: a catch block that only prints
      // the stack trace would leave no CrossValidatorModel to return.
      def loadCrossValidatorModel(path: String): CrossValidatorModel = {
        val fileIn = new FileInputStream(path)
        try {
          val in = new ObjectInputStream(fileIn)
          in.readObject().asInstanceOf[CrossValidatorModel]
        } finally {
          fileIn.close() // runs whether or not readObject succeeds
        }
      }
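
    Since the question is written in Java, the same save-and-load round trip can be sketched there as well. This is only a minimal sketch, not part of the Spark API: `ModelSerializer`, `save`, and `load` are hypothetical names, and an `ArrayList` stands in for the model so that the example runs without Spark. It assumes the object being stored implements `java.io.Serializable`, which is worth verifying for `CrossValidatorModel` in your Spark version.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper (not a Spark class): stores any Serializable object
// in a single local file using plain Java serialization.
final class ModelSerializer {
    private ModelSerializer() {}

    // Writes the object to the given path; try-with-resources closes both streams.
    static void save(Serializable model, String path) throws IOException {
        try (OutputStream fileOut = Files.newOutputStream(Paths.get(path));
             ObjectOutputStream out = new ObjectOutputStream(fileOut)) {
            out.writeObject(model);
        }
    }

    // Reads the object back and casts it to the expected type.
    // Throws ClassCastException if the file holds a different type.
    static <T> T load(String path, Class<T> type)
            throws IOException, ClassNotFoundException {
        try (InputStream fileIn = Files.newInputStream(Paths.get(path));
             ObjectInputStream in = new ObjectInputStream(fileIn)) {
            return type.cast(in.readObject());
        }
    }

    public static void main(String[] args) throws Exception {
        // Assumption: an ArrayList stands in for the trained model.
        Path tmp = Files.createTempFile("model", ".ser");
        ArrayList<String> standIn = new ArrayList<>(Arrays.asList("stage-1", "stage-2"));
        save(standIn, tmp.toString());
        List<?> restored = load(tmp.toString(), ArrayList.class);
        System.out.println(restored.equals(standIn));
        Files.delete(tmp);
    }
}
```

    With Spark on the classpath, the calls would presumably look like `ModelSerializer.save(crossValidatorModel, "/tmp/my.model")` and `ModelSerializer.load("/tmp/my.model", CrossValidatorModel.class)`. A file written this way is tied to the class versions that produced it, so it may not load under a different Spark release.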