
Issue in Pyspark Cross Validation


I'm trying to cross-validate a random forest (RF) model in PySpark with the code below, but it throws an error:

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Training data and number of cross-validation folds
trainData = raw_data_
numFolds = 5

rf = RandomForestClassifier(labelCol="Target", featuresCol="Scaled_features")
evaluator = MulticlassClassificationEvaluator()

pipeline = Pipeline(stages=[rf])
paramGrid = (ParamGridBuilder()
    .addGrid(rf.numTrees, [3, 10])
    .build())
crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=numFolds)

tr_model = crossval.fit(trainData)

Fitting the cross-validator fails with an error:

[Image: error message]

My raw_data_ DataFrame looks like this:

+--------------------+--------------+--------------------+------+
|            features|Position_Group|     Scaled_features|Target|
+--------------------+--------------+--------------------+------+
|[173.735992431640...|           FWD|[12.9261366722264...|     0|
|[188.975997924804...|           FWD|[14.0600087682323...|     0|
|[179.832000732421...|           FWD|[13.3796859647366...|     0|
|[155.752807617187...|           MID|[11.5881692110224...|     2|
|[176.783996582031...|           FWD|[13.1529113184815...|     0|
|[176.783996582031...|           MID|[13.1529113184815...|     2|
|[182.880004882812...|           FWD|[13.6064606109917...|     0|
|[182.880004882812...|           DEF|[13.6064606109917...|     1|
|[182.880004882812...|           FWD|[13.6064606109917...|     0|
|[182.880004882812...|           MID|[13.6064606109917...|     2|
|[188.975997924804...|           DEF|[14.0600087682323...|     1|
|[176.783996582031...|           MID|[13.1529113184815...|     2|
|[170.688003540039...|           MID|[12.6993631612409...|     2|
|[155.447998046875...|           FWD|[11.5654910652351...|     0|
|[188.975997924804...|           FWD|[14.0600087682323...|     0|
|[179.832000732421...|           MID|[13.3796859647366...|     2|
|[188.975997924804...|           MID|[14.0600087682323...|     2|
|[185.927993774414...|           FWD|[13.8332341219772...|     0|
|[176.783996582031...|           FWD|[13.1529113184815...|     0|
|[188.975997924804...|           DEF|[14.0600087682323...|     1|
+--------------------+--------------+--------------------+------+

Any suggestions on why and where the issue is happening? How can the issue be resolved?

Thanks


Solution

  • The error says

    Error while calling evaluate. Field "label" does not exist.

    which points to the evaluator. The evaluator was created without a label column, so it falls back to its default, "label", and no column with that name exists in the DataFrame.

    To resolve this, specify the label column when instantiating the evaluator, just as you did for the classifier, e.g.:

    evaluator = MulticlassClassificationEvaluator(labelCol="Target")