scala, cross-validation, k-fold

Get individual model scores at every iteration / fold in k-fold validation


I am trying to perform k-fold cross-validation in Scala with Spark ML. I am using a random forest model and RMSE as the evaluation metric. So far I can get the RMSE value only for the best model.

Code:

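import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val rf = new RandomForestRegressor().setLabelCol("label").setFeaturesCol("features").setNumTrees(2).setMaxDepth(2)
// The pipeline definition was not shown; it is assumed here to wrap only rf.
val pipeline = new Pipeline().setStages(Array(rf))
val paramGrid = new ParamGridBuilder().build()
val evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("label").setPredictionCol("prediction")

val cv = new CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(2).setParallelism(2)

val cvModel = cv.fit(trainingValDf)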

I would like to print the individual RMSE value for each fold during the validation phase.

E.g.:

(1, 4.3)

(2, 4.4)

(3, 4.2)

.

.

.

(k, rmse for that iteration)

Please let me know how to do this in Scala. Thanks!


Solution

  • CrossValidator calculates the metric for each param map with the following code snippet:

    Spark Cross Validation

    As you can see from the highlighted (yellow) fields, the intermediate metrics are not stored anywhere; only the average is accessible. You can still print the information you want by adjusting the log levels (the underlined calls):

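    import org.apache.log4j.{Level, Logger}
    // Silence Spark in general, then re-enable DEBUG for the ml.util package only.
    Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
    Logger.getLogger("org.apache.spark.ml.util").setLevel(Level.DEBUG)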
    

    The code snippet above turns off all Spark logs and enables only the logs of the util package (the messages you want are printed by the Instrumentation object, which lives in the util package). This generates the following output:

    Result

    However, the ordering is not what you asked for: for the first split it evaluates every parameter combination, then it moves on to the second split. If you want exactly the output you described, the solution is to extend CrossValidator with a CustomCV class and override the fit method. In the CustomCV I wrote for testing purposes, I print each k with its rmse for every input configuration, which results in:

    CustomCV Results
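
    If subclassing CrossValidator is more than you need, the same per-fold numbers can be collected with a plain loop over the folds. The following is only a minimal sketch, not the CustomCV class mentioned above: perFoldMetrics is a hypothetical helper name, the default seed of 42 is an arbitrary assumption, and the folds it produces will only match those used by CrossValidator if both use the same seed.

    ```scala
    import org.apache.spark.ml.{Estimator, Model}
    import org.apache.spark.ml.evaluation.Evaluator
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.sql.DataFrame

    // Hypothetical helper (not part of Spark): fits the estimator once per fold
    // for a single param map and returns (fold number, metric) pairs.
    def perFoldMetrics(
        estimator: Estimator[_],
        evaluator: Evaluator,
        paramMap: ParamMap,
        data: DataFrame,
        numFolds: Int,
        seed: Int = 42): Seq[(Int, Double)] = {
      val spark = data.sparkSession
      val schema = data.schema
      // MLUtils.kFold returns one (training, validation) RDD pair per fold.
      MLUtils.kFold(data.rdd, numFolds, seed).zipWithIndex.map {
        case ((trainRdd, validRdd), i) =>
          val train = spark.createDataFrame(trainRdd, schema).cache()
          val valid = spark.createDataFrame(validRdd, schema).cache()
          // Train on this fold's training split, then score its validation split.
          val model = estimator.fit(train, paramMap).asInstanceOf[Model[_]]
          val metric = evaluator.evaluate(model.transform(valid, paramMap))
          train.unpersist()
          valid.unpersist()
          (i + 1, metric)
      }.toSeq
    }
    ```

    With the names from the question, it could be called as perFoldMetrics(pipeline, evaluator, ParamMap.empty, trainingValDf, numFolds = 2).foreach(println), which prints one (k, rmse) pair per fold. To cover a parameter grid, call it once per entry of paramGrid.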