random-forest apache-spark-mllib

How to get only the predictions with a probability greater than x


I used a random forest to classify texts into categories. On my test data I got an accuracy of 0.98, but on another dataset the overall accuracy drops to 0.7. I suspect that most rows are still predicted correctly and with high confidence.

So now I want to show only the predictions made with high confidence. The random forest gives me a column "probability", which is a vector with one probability per class. How do I get the probability of the chosen prediction?
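To make the question concrete: the predicted class is the index of the largest entry in the probability vector (its argmax), and the confidence of that prediction is the value stored at that index. The plain-Scala sketch below (no Spark) assumes the probability vector has been copied into an `Array[Double]`; `bestWithProbability` and `confidentPrediction` are hypothetical helper names.

```scala
// Hypothetical helper: returns (index of the most probable class, its probability).
// maxBy returns the first index on ties.
def bestWithProbability(probability: Array[Double]): (Int, Double) = {
  val best = probability.indices.maxBy(i => probability(i))
  (best, probability(best))
}

// Hypothetical helper: the predicted class index, but only if its probability is strictly above x.
def confidentPrediction(probability: Array[Double], x: Double): Option[Int] = {
  val (best, p) = bestWithProbability(probability)
  if (p > x) Some(best) else None
}
```

For example, `confidentPrediction(Array(0.1, 0.7, 0.2), 0.5)` yields `Some(1)`, while `Array(0.4, 0.35, 0.25)` with the same threshold yields `None`.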

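    import org.apache.spark.ml.classification.RandomForestClassifier

    // labelIndexer and vectorAssembler are defined elsewhere in the pipeline (not shown)
    val randomForrest = new RandomForestClassifier()
      .setLabelCol(labelIndexer.getOutputCol)
      .setFeaturesCol(vectorAssembler.getOutputCol)
      .setProbabilityCol("probability")   // vector with one probability per class index
      .setSeed(123)
      .setPredictionCol("prediction")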

Solution

  • I eventually came up with the following udf to get the best prediction together with its probability. If there is a more convenient way, please comment.

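    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.sql.functions.udf

    // Returns (index of the most probable class, its probability) as a struct column with fields _1 and _2.
    // rawPrediction is not used; the argmax of the probability vector already gives the predicted index.
    def getBestPrediction = udf((rawPrediction: Vector, probability: Vector) => {
      val bestPrediction = probability.argmax            // Int; the "prediction" column holds the same index as a Double
      val bestProbability = probability(bestPrediction)  // probability of that class
      (bestPrediction, bestProbability)
    })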
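To get back to the title question (keep only rows whose probability is greater than x), the idea can be applied to the predictions DataFrame and then filtered. This is a sketch, assuming `predictions` is the output of the fitted model's `transform` with the default `probability` column; `maxProbability`, `keepConfidentUdf`, `keepConfident` and the `bestProbability` column are hypothetical names. The second variant assumes Spark 3.0+ for `vector_to_array`.

```scala
import org.apache.spark.ml.functions.vector_to_array // available since Spark 3.0
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{array_max, col, udf}

// UDF returning only the probability of the predicted (argmax) class.
val maxProbability = udf((probability: Vector) => probability(probability.argmax))

// Variant 1 (UDF): keep rows whose top-class probability is strictly above threshold.
def keepConfidentUdf(predictions: DataFrame, threshold: Double): DataFrame =
  predictions
    .withColumn("bestProbability", maxProbability(col("probability")))
    .filter(col("bestProbability") > threshold)

// Variant 2 (no UDF): convert the vector to array<double> and take its maximum.
def keepConfident(predictions: DataFrame, threshold: Double): DataFrame =
  predictions
    .withColumn("bestProbability", array_max(vector_to_array(col("probability"))))
    .filter(col("bestProbability") > threshold)
```

The second variant avoids a Scala UDF, which the Catalyst optimizer treats as a black box; on Spark 2.x only the UDF variant applies.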