Tags: apache-spark, pyspark, cross-validation, apache-spark-ml

Does SparkML Cross Validation Only Work With a "label" Column?


When I run the Spark ML cross-validation example on a dataset whose label column is not named "label", I get an IllegalArgumentException on Spark 3.1.1. Why?

The code below is the example modified so that the "label" column is renamed to "target" and labelCol is set to "target" on the logistic regression model. This version raises the exception, while leaving everything at "label" works fine.

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

training = spark.createDataFrame([
    (0, "a b c d e spark", 1.0),
    (1, "b d", 0.0),
    (2, "spark f g h", 1.0),
    (3, "hadoop mapreduce", 0.0),
    (4, "b spark who", 1.0),
    (5, "g d a y", 0.0),
    (6, "spark fly", 1.0),
    (7, "was mapreduce", 0.0),
    (8, "e spark program", 1.0),
    (9, "a e c l", 0.0),
    (10, "spark compile", 1.0),
    (11, "hadoop software", 0.0)
], ["id", "text", "target"]) # try switching between "target" and "label"

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")

lr = LogisticRegression(maxIter=10, labelCol="target")  # try switching between "target" and "label"

pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

paramGrid = ParamGridBuilder() \
    .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
    .addGrid(lr.regParam, [0.1, 0.01]) \
    .build()

crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(),
                          numFolds=2)  


cvModel = crossval.fit(training)

Is that in any way expected behaviour?


Solution

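  • BinaryClassificationEvaluator has its own labelCol parameter, which defaults to "label" and is not inherited from the estimator. Setting labelCol="target" on LogisticRegression alone therefore leaves the evaluator looking for a "label" column that does not exist, which raises the IllegalArgumentException. So if you replace the line

    evaluator=BinaryClassificationEvaluator(),

    with

    evaluator=BinaryClassificationEvaluator(labelCol="target"),

    the cross validation should work fine.

    The parameter is described in the BinaryClassificationEvaluator API docs.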