apache-spark pyspark apache-spark-mllib

Spark ML Naive Bayes predict multiple classes with probabilities


Is there a way to have the model return a list of predicted labels, with a probability score for each label?

For example, given the features (f1, f2, f3), it would return something like: label1:0.50,label2:0.33,...

Is this doable in Spark?


Solution

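  • Yes, it is possible. The model's transform adds a probability column: a Vector with one probability per label index. The rawPrediction column holds unnormalized scores for the same label indices (log-scale for the default multinomial model); they rank the labels in the same order but are not probabilities.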

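    In your example this vector would be [0.50, 0.33, 0.17], where position i is the score for label index i. To get output like label1:0.50,label2:0.33,... you have to write a UDF that pairs each score with its label and turns the result into a String.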
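    The core of such a UDF can be written as plain Scala. This is only a sketch: formatScores is a hypothetical helper name, it assumes labels(i) is the label for score index i, and it sorts the pairs so the most likely label comes first.

    ```scala
    // Hypothetical helper: formats per-label scores as "label:score" pairs,
    // highest score first. Assumes labels(i) is the label for scores(i).
    def formatScores(scores: Array[Double], labels: Array[String]): String = {
      require(scores.length == labels.length, "need exactly one label per score")
      scores
        .zip(labels)
        .sortBy { case (score, _) => -score } // most likely label first
        .map { case (score, label) =>
          // Locale.ROOT keeps "." as the decimal separator on every machine
          val formatted = "%.2f".formatLocal(java.util.Locale.ROOT, score)
          s"$label:$formatted"
        }
        .mkString(",")
    }

    // formatScores(Array(0.17, 0.5, 0.33), Array("label3", "label1", "label2"))
    // returns "label1:0.50,label2:0.33,label3:0.17"
    ```

    In Spark it could then be wrapped as udf((v: Vector) => formatScores(v.toArray, labels)) and applied to the probability column, with labels in index order.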

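    It is important to note that if you used a StringIndexer to encode your label column, the indices in these vectors will not match your original labels: by default StringIndexer orders labels by frequency, so the most frequent label gets index 0. The labels array of the fitted StringIndexerModel maps each index back to the original label.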
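    To see why the StringIndexer indices differ from your original labels, the sketch below mimics the default frequencyDesc ordering in plain Scala. It is an illustration, not Spark's implementation: frequencyDescLabels is a hypothetical name, and the alphabetical tie-break follows recent Spark documentation (check your version). In a real pipeline, read the order from the fitted StringIndexerModel instead.

    ```scala
    // Hypothetical illustration of StringIndexer's default "frequencyDesc" order:
    // the most frequent label gets index 0, ties are broken alphabetically.
    def frequencyDescLabels(values: Seq[String]): Array[String] =
      values
        .groupBy(identity)
        .map { case (label, occurrences) => (label, occurrences.size) }
        .toArray
        .sortBy { case (label, count) => (-count, label) }
        .map { case (label, _) => label }

    val indexToLabel = frequencyDescLabels(Seq("cat", "dog", "dog", "bird", "dog", "cat"))
    // indexToLabel(0) is "dog": index 0.0 in the indexed column means "dog",
    // and position 0 of the probability vector is the score for "dog".
    ```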

    I had some code that does something similar, which can be adapted to your use case. It writes the top X predicted labels for each row to a CSV file. The df parameter of writeToCsv must be a DataFrame that has already been transformed by your Naive Bayes model and contains an id column, and labelsBroadcast must hold the labels in index order (for example, the StringIndexerModel labels).

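    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.sql.{DataFrame, SaveMode}
    import org.apache.spark.sql.functions.{col, concat_ws, lit, udf}

    // Returns the topX labels with the highest score, best first.
    // labels.value(i) must be the label for vector index i.
    def topXPredictions(v: Vector, labels: Broadcast[Array[String]], topX: Int): Array[String] = {
      val labelVal = labels.value
      v.toArray
        .zip(labelVal)
        .sortBy { case (score, _) => -score } // highest score first
        .map { case (_, label) => label }
        .take(topX)
    }

    // Writes "id,top10Predictions" as a single CSV file under the path `name`.
    // df must be the output of the Naive Bayes model's transform and have an "id" column.
    def writeToCsv(df: DataFrame, labelsBroadcast: Broadcast[Array[String]], name: String = "output"): Unit = {
      val get_top_predictions = udf((v: Vector, x: Int) => topXPredictions(v, labelsBroadcast, x))

      df
        .select(
          col("id"),
          // rawPrediction ranks the labels in the same order as probability
          concat_ws(" ", get_top_predictions(col("rawPrediction"), lit(10))).alias("top10Predictions")
        )
        .orderBy("id")
        .coalesce(1) // single output file; only sensible for small results
        .write
        .mode(SaveMode.Overwrite)
        .format("csv") // built-in since Spark 2.0, replaces com.databricks.spark.csv
        .option("header", "true")
        .save(name)
    }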
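    The code above ranks labels by rawPrediction. For the default multinomial Naive Bayes model this gives the same order as the probability column, because the probabilities are obtained by exponentiating and normalizing the log-scale raw scores (a softmax), which never changes their order. A minimal sketch of that normalization, assuming log-scale raw scores; softmax here is a hypothetical helper written for illustration, not a Spark API:

    ```scala
    // Hypothetical helper: turns log-scale scores into probabilities summing to 1.
    // Subtracting the maximum first keeps math.exp from overflowing.
    def softmax(raw: Array[Double]): Array[Double] = {
      val maxLog = raw.max
      val exps = raw.map(r => math.exp(r - maxLog))
      val total = exps.sum
      exps.map(_ / total)
    }

    val rawScores = Array(-2.0, -2.5, -3.1) // made-up log-scale scores
    val probs = softmax(rawScores)
    // The index with the largest raw score also has the largest probability,
    // so sorting by either column yields the same top X labels.
    ```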