Search code examples
apache-sparkrandom-forestcross-validationapache-spark-mlapache-spark-mllib

How to cross validate RandomForest model?


I want to evaluate a random forest trained on some data. Does Apache Spark provide a utility for this, or do I have to perform cross-validation manually?


Solution

  • Spark ML provides the `CrossValidator` class, which can be used to perform cross-validation and parameter search. Assuming your data is already preprocessed, you can add cross-validation as follows:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    
    // Inputs (placeholders to fill in). The training data is assumed to be
    // already preprocessed, with schema [label: double, features: vector].
    val trainingData: org.apache.spark.sql.DataFrame = ???
    val nFolds: Int = ???   // number of cross-validation folds
    val numTrees: Int = ??? // number of trees in the forest
    val metric: String = ??? // evaluator metric name, see below
    
    // The classifier to evaluate, reading the "label"/"features" columns.
    val rf = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setNumTrees(numTrees)
    
    // Wrap the classifier in a Pipeline so preprocessing stages can be added
    // later and will then be re-fit inside every fold.
    val pipeline = new Pipeline().setStages(Array(rf))
    
    // No grids are added, so only the estimator's current parameters are used:
    // plain cross-validation without hyperparameter search.
    val paramGrid = new ParamGridBuilder().build()
    
    // Scores each fold's predictions against the held-out labels.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      // "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
      .setMetricName(metric)
    
    val cv = new CrossValidator()
      // ml.Pipeline with ml.classification.RandomForestClassifier
      .setEstimator(pipeline)
      // ml.evaluation.MulticlassClassificationEvaluator
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(nFolds)
    
    // Runs k-fold cross-validation on the training DataFrame.
    val model = cv.fit(trainingData)
    

    Using PySpark:

    from pyspark.ml import Pipeline
    from pyspark.ml.classification import RandomForestClassifier
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    
    trainingData = ... # DataFrame[label: double, features: vector]
    numFolds = ... # Integer: number of cross-validation folds
    
    # The classifier to evaluate, reading the "label"/"features" columns.
    rf = RandomForestClassifier(labelCol="label", featuresCol="features")
    evaluator = MulticlassClassificationEvaluator() # + other params as in Scala
    
    pipeline = Pipeline(stages=[rf])
    # ParamGridBuilder must be instantiated before chaining addGrid calls.
    paramGrid = (ParamGridBuilder()
        .addGrid(rf.numTrees, [3, 10])
        # .addGrid(rf.maxDepth, [5, 10])  # add other parameters as needed
        .build())
    
    # Every parameter combination in paramGrid is scored with k-fold
    # cross-validation using the evaluator above.
    crossval = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=paramGrid,
        evaluator=evaluator,
        numFolds=numFolds)
    
    model = crossval.fit(trainingData)