
Writing the model output to a text file in Spark (Scala)

I fitted the following logistic regression model using Spark MLlib:

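import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.FeatureHasher
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import spark.implicits._ // needed for $"..." and .as[(Double, Double)]

val df = spark.read.option("header","true").option("inferSchema","true").csv("car_milage-6f50d.csv")
val hasher = new FeatureHasher().setInputCols(Array("mpg","displacement","hp","torque")).setOutputCol("features")
val transformed = hasher.transform(df)
val Array(training, test) = transformed.randomSplit(Array(0.8, 0.2))
// assumes "automatic" is the 0/1 label column
val lr = new LogisticRegression().setLabelCol("automatic")
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.3))
  .addGrid(lr.elasticNetParam, Array(0.9, 1.0))
  .build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator().setLabelCol("automatic"))
  .setEstimatorParamMaps(paramGrid)

val model = cv.fit(training)
val results = model.transform(test).select("features", "automatic", "prediction")

val predictionAndLabels = results.select($"prediction", $"automatic".cast("double")).as[(Double, Double)].rdd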

At the end, I obtained these model evaluation metrics:

val mMetrics = new MulticlassMetrics(predictionAndLabels)

As the final step, I need to write these evaluation metrics (mMetrics) to a file (either a text file or a CSV file). Can anyone help me with how to do that?

I tried, but I couldn't find any write method associated with these values.

Thank you


  • Judging from the method summary of MulticlassMetrics, you should be able to do it this way:

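    // toArray flattens the confusion matrix in column-major order
    val confusionMatrixOutput = mMetrics.confusionMatrix.toArray
    // SparkSession has no parallelize method; go through its SparkContext
    val confusionMatrixOutputFinal = spark.sparkContext.parallelize(confusionMatrixOutput)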

    You should be able to do the same with mMetrics.labels:

    val labelsOutput = mMetrics.labels
    val labelsOutputFinal = spark.sparkContext.parallelize(labelsOutput)

    Accuracy is just a Double, so you can print it directly:

    val accuracy = mMetrics.accuracy
    println("Summary Statistics")
    println(s"Accuracy = $accuracy")

    To write all the statistics for your logistic regression model to a single file, use a java.io.PrintWriter on the driver, like this:

      import java.io.{File, PrintWriter}
      import org.apache.spark.mllib.evaluation.MulticlassMetrics

      object MulticlassMetricsOutputWriter {
        def main(args: Array[String]): Unit = {
          // All your other code (SparkSession, model fitting, predictionAndLabels) goes here
          val mMetrics = new MulticlassMetrics(predictionAndLabels)
          val labels = mMetrics.labels
          // Create a new file and pass a reference to it to the PrintWriter.
          // This runs on the driver, so the path is on the driver's local file system.
          val pw = new PrintWriter(new File("C:/mllib_lr_output.txt"))
          try {
            // Confusion matrix: rows are actual labels, columns are predicted labels
            pw.write("Confusion matrix:\n" + mMetrics.confusionMatrix.toString + "\n")
            // Labels
            pw.write("Labels: " + labels.mkString(", ") + "\n")
            // False positive rate by label
            labels.foreach { l =>
              pw.write(s"FPR($l) = " + mMetrics.falsePositiveRate(l) + "\n")
            }
            // True positive rate by label
            labels.foreach { l =>
              pw.write(s"TPR($l) = " + mMetrics.truePositiveRate(l) + "\n")
            }
            // F-measure by label
            labels.foreach { l =>
              pw.write(s"F1-Score($l) = " + mMetrics.fMeasure(l) + "\n")
            }
            // Precision by label
            labels.foreach { l =>
              pw.write(s"Precision($l) = " + mMetrics.precision(l) + "\n")
            }
            // Recall by label
            labels.foreach { l =>
              pw.write(s"Recall($l) = " + mMetrics.recall(l) + "\n")
            }
            val accuracy = mMetrics.accuracy
            val weightedFalsePositiveRate = mMetrics.weightedFalsePositiveRate
            val weightedFMeasure = mMetrics.weightedFMeasure
            val weightedPrecision = mMetrics.weightedPrecision
            val weightedRecall = mMetrics.weightedRecall
            val weightedTruePositiveRate = mMetrics.weightedTruePositiveRate
            pw.write("Summary Statistics\n")
            pw.write(s"Accuracy = $accuracy\n")
            pw.write(s"weightedFalsePositiveRate = $weightedFalsePositiveRate\n")
            pw.write(s"weightedFMeasure = $weightedFMeasure\n")
            pw.write(s"weightedPrecision = $weightedPrecision\n")
            pw.write(s"weightedRecall = $weightedRecall\n")
            pw.write(s"weightedTruePositiveRate = $weightedTruePositiveRate\n")
          } finally {
            // Closing the PrintWriter flushes the output and releases the file
            pw.close()
          }
        }
      }
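If you want a CSV file instead of free text, the same driver-side approach works, because every MulticlassMetrics getter returns a plain Double and the confusion matrix is a small local matrix. The sketch below is a minimal, standard-library-only illustration under that assumption; the object MetricsCsv and its helpers writeMetricsCsv and confusionMatrixRows are hypothetical names, not part of Spark. The second helper is there because Matrix.toArray flattens the matrix in column-major order, so a flat array such as confusionMatrixOutput loses its row structure unless you rebuild it.

```scala
import java.io.{File, PrintWriter}
import scala.io.Source

object MetricsCsv {
  // Hypothetical helper: writes (metric, value) pairs as a two-column CSV file.
  // Metric names are assumed not to contain commas.
  def writeMetricsCsv(path: String, rows: Seq[(String, Double)]): Unit = {
    val pw = new PrintWriter(new File(path))
    try {
      pw.println("metric,value")
      rows.foreach { case (name, value) => pw.println(s"$name,$value") }
    } finally {
      pw.close()
    }
  }

  // Hypothetical helper: turns a column-major array (the layout returned by
  // Matrix.toArray) back into one comma-separated line per matrix row.
  def confusionMatrixRows(values: Array[Double], numRows: Int, numCols: Int): Seq[String] = {
    require(values.length == numRows * numCols, "array size must equal numRows * numCols")
    (0 until numRows).map { i =>
      // Element (i, j) of a column-major matrix sits at index j * numRows + i
      (0 until numCols).map(j => values(j * numRows + i)).mkString(",")
    }
  }

  def main(args: Array[String]): Unit = {
    // Stand-in numbers for illustration only; in the real job use values
    // such as mMetrics.accuracy and mMetrics.confusionMatrix.toArray.
    val tmp = File.createTempFile("metrics", ".csv")
    tmp.deleteOnExit()
    writeMetricsCsv(tmp.getPath, Seq("accuracy" -> 0.5, "weightedRecall" -> 0.5))
    val src = Source.fromFile(tmp)
    try src.getLines().foreach(println) finally src.close()
    // Column-major [a, c, b, d] represents the 2x2 matrix [[a, b], [c, d]]
    confusionMatrixRows(Array(1.0, 3.0, 2.0, 4.0), 2, 2).foreach(println)
  }
}
```

In the Spark job you might call MetricsCsv.writeMetricsCsv("metrics.csv", Seq("accuracy" -> mMetrics.accuracy, "weightedFMeasure" -> mMetrics.weightedFMeasure)), and for the matrix val cm = mMetrics.confusionMatrix followed by MetricsCsv.confusionMatrixRows(cm.toArray, cm.numRows, cm.numCols). Writing on the driver keeps everything in one file; calling saveAsTextFile on a parallelized RDD would instead produce a directory of part files.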