Tags: scala, apache-spark, dataframe, apache-spark-sql, apache-spark-ml

Cannot run RandomForestClassifier from spark ML on a simple example


I have tried to run the experimental RandomForestClassifier from the spark.ml package (version 1.5.2). The dataset I used is from the LogisticRegression example in the Spark ML guide.

Here is the code:

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.Row

// Prepare training data from a list of (label, features) tuples.
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

val rf = new RandomForestClassifier()

val model = rf.fit(training)

And here is the error I obtain:

java.lang.IllegalArgumentException: RandomForestClassifier was given input with invalid label column label, without the number of classes specified. See StringIndexer.
    at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:87)
    at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:42)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:90)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:48)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:53)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:55)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:57)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:59)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:61)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:63)
    at $iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:65)
    at $iwC$$iwC$$iwC$$iwC.<init>(<console>:67)
    at $iwC$$iwC$$iwC.<init>(<console>:69)
    at $iwC$$iwC.<init>(<console>:71)
    at $iwC.<init>(<console>:73)
    at <init>(<console>:75)
    at .<init>(<console>:79)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at org.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:1065)
    at org.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1340)
    at org.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:840)
    at org.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:871)
    at org.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:819)
    at org.apache.spark.repl.SparkILoop.reallyInterpret$1(SparkILoop.scala:857)
    at org.apache.spark.repl.SparkILoop.interpretStartingWith(SparkILoop.scala:902)
    at org.apache.spark.repl.SparkILoop.command(SparkILoop.scala:814)
    at org.apache.spark.repl.SparkILoop.processLine$1(SparkILoop.scala:657)
    at org.apache.spark.repl.SparkILoop.innerLoop$1(SparkILoop.scala:665)
    at org.apache.spark.repl.SparkILoop.org$apache$spark$repl$SparkILoop$$loop(SparkILoop.scala:670)
    at org.apache.spark.repl.SparkILoop$$anonfun$org$apache$spark$repl$SparkILoop$$process$1.apply$mcZ$sp(SparkILoop.scala:997)
    at org.apache.spark.repl.SparkILoop$$anonfun$org$apache$spark$repl$SparkILoop$$process$1.apply(SparkILoop.scala:945)
    at org.apache.spark.repl.SparkILoop$$anonfun$org$apache$spark$repl$SparkILoop$$process$1.apply(SparkILoop.scala:945)
    at scala.tools.nsc.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:135)
    at org.apache.spark.repl.SparkILoop.org$apache$spark$repl$SparkILoop$$process(SparkILoop.scala:945)
    at org.apache.spark.repl.SparkILoop.process(SparkILoop.scala:1059)
    at org.apache.spark.repl.Main$.main(Main.scala:31)
    at org.apache.spark.repl.Main.main(Main.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:674)
    at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:180)
    at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:205)
    at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:120)
    at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)

The problem appears when the function tries to compute the number of classes in the column "label".

As you can see at line 84 of the RandomForestClassifier source code, the method calls DataFrame.schema with the argument "label". This call succeeds and returns an org.apache.spark.sql.types.StructField object. The method then passes that field to org.apache.spark.ml.util.MetadataUtils.getNumClasses. Because getNumClasses returns None instead of the number of classes, an exception is raised at line 87.

After a quick glance at the getNumClasses source code, I suppose the error occurs because the metadata of column "label" describes neither a BinaryAttribute nor a NominalAttribute. However, I do not know how to fix this problem.
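One way to confirm this diagnosis is to print the metadata Spark stores on the label column. The sketch below assumes Spark 2.x (for SparkSession and org.apache.spark.ml.linalg) and a local session; it is meant to be pasted into spark-shell or run as a Scala script. A plain numeric column has no "ml_attr" entry, while the column produced by StringIndexer carries nominal-attribute metadata, which is what getNumClasses reads.

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Reuses the shell's session if one exists, otherwise starts a local one (assumption: Spark 2.x).
val spark = SparkSession.builder().master("local[*]").appName("LabelMetadataCheck").getOrCreate()

val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

// Raw numeric label: no "ml_attr" key, so no number of classes can be read from it.
println(training.schema("label").metadata.json)

// StringIndexer output: metadata describes a nominal attribute and its values.
val indexed = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(training)
  .transform(training)
println(indexed.schema("indexedLabel").metadata.json)
```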

My question:

How can I fix this problem?

Thanks a lot for reading my question and for your help!


Solution

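  • Let's first fix the imports to remove ambiguity. Note that org.apache.spark.ml.linalg.Vectors only exists from Spark 2.0 onward; on Spark 1.5.x, import org.apache.spark.mllib.linalg.Vectors instead: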

    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
    import org.apache.spark.ml.{Pipeline, PipelineStage}
    import org.apache.spark.ml.linalg.Vectors
    

    I'll use the same data you used:

    val training = sqlContext.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.0)),
      (0.0, Vectors.dense(2.0, 1.3, 1.0)),
      (1.0, Vectors.dense(0.0, 1.2, -0.5))
    )).toDF("label", "features")
    

    and then create the pipeline stages:

    val stages = new scala.collection.mutable.ArrayBuffer[PipelineStage]()
    
    // 1. For classification, re-index the label classes (this attaches the class metadata the classifier needs)
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(training)
    
    // 2. Identify categorical features using VectorIndexer
    val featuresIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(10).fit(training)
    stages += featuresIndexer
    
    // Only the label is indexed here; the featuresIndexer stage of the pipeline adds indexedFeatures itself
    val tmp = labelIndexer.transform(training)
    
    // 3. Learn Random Forest
    val rf = new RandomForestClassifier().setFeaturesCol(featuresIndexer.getOutputCol).setLabelCol(labelIndexer.getOutputCol)
    
    stages += rf
    val pipeline = new Pipeline().setStages(stages.toArray)
    
    // Fit the Pipeline
    val pipelineModel = pipeline.fit(tmp)
    
    val results = pipelineModel.transform(training)
    
    results.show
    
    //+-----+--------------+---------------+-------------+-----------+----------+
    //|label|      features|indexedFeatures|rawPrediction|probability|prediction|
    //+-----+--------------+---------------+-------------+-----------+----------+
    //|  1.0| [0.0,1.1,0.1]|  [0.0,1.0,2.0]|   [1.0,19.0]|[0.05,0.95]|       1.0|
    //|  0.0|[2.0,1.0,-1.0]|  [1.0,0.0,0.0]|   [17.0,3.0]|[0.85,0.15]|       0.0|
    //|  0.0| [2.0,1.3,1.0]|  [1.0,3.0,3.0]|   [14.0,6.0]|  [0.7,0.3]|       0.0|
    //|  1.0|[0.0,1.2,-0.5]|  [0.0,2.0,1.0]|   [1.0,19.0]|[0.05,0.95]|       1.0|
    //+-----+--------------+---------------+-------------+-----------+----------+
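A variant of the same solution, sketched under the same Spark 2.x assumption, puts the fitted StringIndexer inside the pipeline so the pipeline can be fitted directly on training, and adds an IndexToString stage that maps the numeric prediction back to the original label value. The output column name predictedLabel is an arbitrary choice.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("RandomForestPipeline").getOrCreate()

val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

// Fitted up front so that its label list can be handed to IndexToString.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(training)

// Features with at most 10 distinct values are treated as categorical.
val featuresIndexer = new VectorIndexer()
  .setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(10).fit(training)

val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

// Converts the prediction index back to the original label (as a string).
val labelConverter = new IndexToString()
  .setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

val pipeline = new Pipeline().setStages(Array(labelIndexer, featuresIndexer, rf, labelConverter))
val pipelineModel = pipeline.fit(training)

val results = pipelineModel.transform(training)
results.select("label", "prediction", "predictedLabel").show()
```

With the label indexer inside the pipeline, the same fitted model can be applied to new data without indexing the label by hand first.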
    

    References: for more details on the feature transformers used in steps 1 and 2, I advise you to read the official Spark ML documentation on feature transformers.
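If you would rather keep the numeric labels unchanged, another option is to attach the nominal-attribute metadata to the label column yourself, which supplies the "number of classes" the error message asks for. This is a sketch under the same Spark 2.x assumption, using the developer-API classes in org.apache.spark.ml.attribute, and it assumes the labels are already the consecutive values 0.0 and 1.0.

```scala
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("NominalLabelMetadata").getOrCreate()

val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

// Describe "label" as a nominal (categorical) attribute with 2 possible values.
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(2).toMetadata()

// withColumn replaces the existing "label" column with the same values plus the metadata.
val labeled = training.withColumn("label", training("label").as("label", labelMeta))

// Default column names "label" and "features" match the DataFrame.
val model = new RandomForestClassifier().fit(labeled)
println(model.numClasses)
```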