Tags: apache-spark, apache-spark-sql, apache-spark-mllib, apache-spark-ml

Convert Seq[(String, Any)] to Seq[(String, org.apache.spark.ml.PredictionModel[_, _])] in Spark


I have trained my dataset with several different models: nbModel, dtModel, rfModel and gbmModel. All of these are machine learning models.

Now, when I store them in a variable like this:

val models = Seq(("NB", nbModel), ("DT", dtModel), ("RF", rfModel), ("GBM",gbmModel))

I get a Seq[(String, Any)]:

models: Seq[(String, Any)] = List((NB,NaiveBayesModel (uid=nb_c35f79982850) with 2 classes), (DT,()), (RF,RandomForestClassificationModel (uid=rfc_3f42daf4ea14) with 15 trees), (GBM,GBTClassificationModel (uid=gbtc_534a972357fa) with 20 trees))

If I assign an individual model, such as nbModel, on its own:

 val models = ("NB", nbModel)

OUTPUT: models: (String, org.apache.spark.ml.classification.NaiveBayesModel) = (NB,NaiveBayesModel (uid=nb_c35f79982850) with 2 classes)

And when I try to merge a few columns from those models, I get a type mismatch error:

val mlTrainData= mlData(transferData, "value", models).drop("row_id")

<console>:75: error: type mismatch;
 found   : Seq[(String, Any)]
 required: Seq[(String, org.apache.spark.ml.PredictionModel[_, _])]
       val mlTrainData= mlData(transferData, "value", models).drop("row_id")

Also, my mlData function is:

import org.apache.spark.ml.PredictionModel
import org.apache.spark.sql.DataFrame

def mlData(inputData: DataFrame, responseColumn: String,
           baseModels: Seq[(String, PredictionModel[_, _])]): DataFrame = {
  baseModels.map { case (name, model) =>
    // Score the input with each model, keeping only row_id and that model's prediction
    model.transform(inputData)
      .select("row_id", model.getPredictionCol)
      .withColumnRenamed(model.getPredictionCol, s"${name}_prediction")
  }.reduceLeft((a, b) => a.join(b, Seq("row_id"), "inner")) // one prediction column per model
    .join(inputData.select("row_id", responseColumn), Seq("row_id"), "inner")
}

OUTPUT: mlData: (inputData: org.apache.spark.sql.DataFrame, responseColumn: String, baseModels: Seq[(String, org.apache.spark.ml.PredictionModel[_, _])])org.apache.spark.sql.DataFrame


Solution

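  • Try replacing the code

    val models = Seq(("NB", nbModel), ("DT", dtModel), ("RF", rfModel), ("GBM",gbmModel))
    

    with

    val models = Seq(("NB", nbModel), ("DT", null: org.apache.spark.ml.classification.DecisionTreeClassificationModel), ("RF", rfModel), ("GBM", gbmModel))
    

    The point I am trying to make is that your dtModel was assigned (), which is of type Unit. Scala therefore infers the element type of the whole Seq as the common supertype of the model classes and Unit, and because Unit is an AnyVal while the models are AnyRefs, that supertype is Any. You need to make sure dtModel is an org.apache.spark.ml.classification.DecisionTreeClassificationModel, the spark.ml model that extends PredictionModel like your other three; the older org.apache.spark.mllib.tree.model.DecisionTreeModel is not a PredictionModel and would not type-check either. A null of that type is fine only if you handle it: mlData as written calls model.transform on every entry, which would throw a NullPointerException, so filter the null out first. Any non-null instance of that type, such as a correctly fitted dtModel, works as well.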
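    To see the type inference in isolation, here is a small Spark-free sketch. PredictionModel, NaiveBayesModel and TreeModel below are hypothetical stand-ins (not the real Spark classes), and modelNames is a made-up helper with the same parameter type as mlData's baseModels. It is meant to show, under Scala 2 inference (which Spark uses), that one Unit value widens the Seq's element type to Any, and that annotating the val makes the compiler check each element where the Seq is built.

```scala
// Hypothetical stand-ins for Spark's model classes, so this runs without Spark.
abstract class PredictionModel[FeaturesType, M]
class NaiveBayesModel extends PredictionModel[Array[Double], NaiveBayesModel]
class TreeModel extends PredictionModel[Array[Double], TreeModel]

object LubDemo {
  // Hypothetical helper with the same parameter type as mlData's baseModels.
  def modelNames(models: Seq[(String, PredictionModel[_, _])]): Seq[String] =
    models.map { case (name, _) => name }

  def main(args: Array[String]): Unit = {
    val nbModel = new NaiveBayesModel
    val unitDtModel: Unit = ()      // what dtModel evaluated to in the question
    val treeDtModel = new TreeModel // what dtModel should be

    // Unit is an AnyVal and the models are AnyRefs, so their only common
    // supertype is Any: Scala 2 infers this as Seq[(String, Any)].
    val widened = Seq(("NB", nbModel), ("DT", unitDtModel))
    // modelNames(widened)  // does not compile: type mismatch, found Seq[(String, Any)]

    // The explicit annotation makes the compiler check each element here;
    // using unitDtModel instead of treeDtModel would fail on this line.
    val models: Seq[(String, PredictionModel[_, _])] =
      Seq(("NB", nbModel), ("DT", treeDtModel))

    println(widened.map(_._1))  // List(NB, DT)
    println(modelNames(models)) // List(NB, DT)
  }
}
```

    In the real code, declaring val models: Seq[(String, PredictionModel[_, _])] = Seq(...) should likewise make the compiler report the Unit-valued dtModel on that line, rather than later at the mlData call.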