apache-spark-ml

How to retrieve the CrossValidator bestModel ALS regParam value?


I'm training an ALS model with a CrossValidator:

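  import org.apache.spark.ml.evaluation.RegressionEvaluator
  import org.apache.spark.ml.recommendation.ALS
  import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

  val als = new ALS()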
    .setMaxIter(5)
    .setUserCol("userId")
    .setItemCol("movieId")
    .setRatingCol("rating")

  val evaluator = new RegressionEvaluator()
    .setMetricName("rmse")
    .setLabelCol("rating")
    .setPredictionCol("prediction")

  val paramGrid = new ParamGridBuilder()
    .addGrid(als.regParam, Array(0.001, 0.01, 0.1, 1.0))
    .build()

  val cv = new CrossValidator()
    .setEstimator(als)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(3)

  val cvModel = cv.fit(training)

I would like to inspect the chosen regParam value. I've tried this:

  val bestRegParam = cvModel.bestModel.getRegParam()

However, this fails to compile with the error:

value getRegParam is not a member of org.apache.spark.ml.Model[_$5]


Solution

  • Usually you have to cast bestModel to a specific model class, e.g. ALSModel. However, the ALSModel class doesn't have a regParam field, so the cast alone won't expose the value. As far as I can tell there is no direct way to read the chosen regParam from the model itself, and it's really a question for the Spark developers.

    You could turn on logging for CrossValidator, since it logs the best set of parameters it selects.
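
Besides logging, two indirect routes may be worth trying. The sketch below is not verified against a specific Spark version. It assumes that bestModel.parent still points at the ALS estimator copy that was fitted with the winning parameters, and that cvModel.avgMetrics lines up index-for-index with cvModel.getEstimatorParamMaps. The helper names regParamFromParent and regParamFromMetrics are made up for illustration.

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.tuning.CrossValidatorModel

// Route 1: ask the estimator that produced the best model.
// Assumption: parent is the ALS copy trained with the winning ParamMap.
// parent is not persisted, so this is unlikely to work on a model loaded from disk.
def regParamFromParent(cvModel: CrossValidatorModel): Double =
  cvModel.bestModel.parent.asInstanceOf[ALS].getRegParam

// Route 2: pair each candidate ParamMap with its average metric and
// pick the best one, respecting the evaluator's direction (lower is better for rmse).
def regParamFromMetrics(cvModel: CrossValidatorModel, als: ALS): Double = {
  val scored = cvModel.getEstimatorParamMaps.zip(cvModel.avgMetrics)
  val (bestParams, _) =
    if (cvModel.getEvaluator.isLargerBetter) scored.maxBy(_._2)
    else scored.minBy(_._2)
  bestParams(als.regParam) // look up regParam in the winning ParamMap
}
```

With the cvModel and als values from the question, regParamFromParent(cvModel) and regParamFromMetrics(cvModel, als) should both return one of the four grid values. The second route has the side benefit of showing the metric for every candidate, not just the winner.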