Tags: scala, apache-spark, apache-spark-ml

How can I make a function generic on an MLReader


I am working in Spark 1.6.3. Here are two functions that do the same thing:

import java.nio.file.{Files, Path}
import org.apache.spark.ml.feature.{CountVectorizerModel, IDFModel}

def modelFromBytesCV(modelArray: Array[Byte]): CountVectorizerModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  CountVectorizerModel.read.load(tempPath.toString)
}

def modelFromBytesIDF(modelArray: Array[Byte]): IDFModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  IDFModel.read.load(tempPath.toString)
}

I would like to make these functions generic. What I am hung up on is that the common trait between the CountVectorizerModel object and the IDFModel object is MLReadable[T], which itself must take either CountVectorizerModel or IDFModel as its type parameter. This looks like a recursive parent-class loop that I can't figure out a way around.

By comparison, the generic model writer is easy, because MLWritable is a common trait extended by all the models I am interested in:

def modelToBytes[M <: MLWritable](model: M): Array[Byte] = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  model.write.overwrite().save(tempPath.toString)
  Files.readAllBytes(tempPath)
}

How can I make a generic reader that will turn a byte array back into a spark-ml model?


Solution

  • To make it work you'll need access to a specific MLReadable object.

    import java.nio.file.{Files, Path}
    import org.apache.spark.ml.util.MLReadable
    
    def modelFromBytes[M](obj: MLReadable[M], modelArray: Array[Byte]): M = {
      val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
      Files.write(tempPath, modelArray)
      obj.read.load(tempPath.toString)
    }
    

    which can later be used like this:

    val bytes: Array[Byte] = ???
    modelFromBytes(CountVectorizerModel, bytes)
    

    Note that, despite first appearances, there is nothing recursive here - MLReadable[M] refers to the companion object, not the class as such. So, for example, the CountVectorizerModel object is MLReadable, while the CountVectorizerModel class isn't.
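
    The same pattern can be sketched in plain Scala without Spark; Readable and Model below are hypothetical stand-ins for MLReadable and a model class, showing how the companion object, not the class, implements the trait:

```scala
// Toy stand-in for MLReadable (hypothetical names, not Spark's API).
trait Readable[M] {
  def read(s: String): M
}

// The class itself is not Readable...
class Model(val value: String)

// ...but its companion object is: Readable[Model] is parameterized by the
// class while being implemented by the object, so nothing is recursive.
object Model extends Readable[Model] {
  def read(s: String): Model = new Model(s)
}

// A generic function only needs the companion object passed in.
def fromString[M](obj: Readable[M], s: String): M = obj.read(s)

val m: Model = fromString(Model, "hello")
println(m.value) // prints "hello"
```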

    Internally, Spark's MLReader handles this in a different way - it creates an instance of the class using reflection, and then sets its Params. However, this path won't be very useful to you here*.

    If compatibility with the current API is required, you can try making readable object implicit:

    def modelFromBytes[M](modelArray: Array[Byte])(implicit obj: MLReadable[M]): M = {
      val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
      Files.write(tempPath, modelArray)
      obj.read.load(tempPath.toString)
    }
    

    and then

    implicit val readable: MLReadable[CountVectorizerModel] = CountVectorizerModel
    
    modelFromBytes[CountVectorizerModel](bytes)
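
    As an aside, since the implicit argument is only used to obtain the reader, the same signature can also be written with a context bound, which is equivalent plain-Scala sugar. A self-contained sketch, using hypothetical stand-in types rather than Spark's own:

```scala
// Hypothetical stand-ins for MLReadable and a model class.
trait Readable[M] { def read(s: String): M }

class Model(val value: String)
object Model extends Readable[Model] {
  def read(s: String): Model = new Model(s)
}

implicit val readable: Readable[Model] = Model

// Explicit implicit-parameter form...
def fromString[M](s: String)(implicit obj: Readable[M]): M = obj.read(s)

// ...and the equivalent context-bound form.
def fromString2[M: Readable](s: String): M = implicitly[Readable[M]].read(s)

println(fromString[Model]("a").value)  // prints "a"
println(fromString2[Model]("b").value) // prints "b"
```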
    

    * Technically speaking it is possible to get companion object via reflection

    import scala.reflect.ClassTag
    
    def modelFromBytesCV[M <: MLWritable](
        modelArray: Array[Byte])(implicit ct: ClassTag[M]): M = {
      val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
      Files.write(tempPath, modelArray)
      // Look up the companion object (the MODULE$ field of the "Foo$" class)
      val cls = Class.forName(ct.runtimeClass.getName + "$")
      cls.getField("MODULE$").get(cls).asInstanceOf[MLReadable[M]]
        .read.load(tempPath.toString)
    }
    

    but I don't think that is a path worth exploring here. In particular, we cannot really provide strict type bounds this way - using MLWritable as the bound is a hack to limit human error, but it is rather useless for the compiler.