Change PySpark StringIndexer inputCol param when wrapped in a Pipeline object


I'm building a Pipeline object to encode my category column using a StringIndexer object.

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer

indexers = [StringIndexer(inputCol='FirstName',
                          outputCol='FirstName_new',
                          handleInvalid='keep',
                          stringOrderType='frequencyDesc').fit(df)]

pipeline = Pipeline(stages=indexers)

pipeline.write().overwrite().save(path)

I want to reuse the same pipeline object on another column (I have a specific use case that requires this). Is there any way I can change the inputCol parameter?


Solution

  • You can use the setInputCol method to change the input column name.

    indexers = [StringIndexer(inputCol='FirstName',
                              outputCol='FirstName_new',
                              handleInvalid='keep',
                              stringOrderType='frequencyDesc')]
    
    pipeline = Pipeline(stages=indexers)
    
    >>> print(pipeline.getStages()[0].getInputCol())
    FirstName
    
    pipeline.getStages()[0].setInputCol('test')
    
    >>> print(pipeline.getStages()[0].getInputCol())
    test
    

    Note that you should not call fit(df) on the stages when building the pipeline; fit the data with the pipeline itself, e.g. pipeline.fit(df), as sketched below.
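
    Putting both points together, here is a minimal sketch; `df` and `path` are the DataFrame and save path from the question, and 'LastName' is a hypothetical second column to retarget the indexer to.

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer

    indexers = [StringIndexer(inputCol='FirstName',
                              outputCol='FirstName_new',
                              handleInvalid='keep',
                              stringOrderType='frequencyDesc')]

    pipeline = Pipeline(stages=indexers)

    # Retarget the (unfitted) stage to another column before fitting
    stage = pipeline.getStages()[0]
    stage.setInputCol('LastName')
    stage.setOutputCol('LastName_new')

    # Fit the whole pipeline to the data, then persist the fitted PipelineModel
    model = pipeline.fit(df)
    model.write().overwrite().save(path)

    Because the stages are only fitted when pipeline.fit(df) runs, the updated inputCol and outputCol take effect on the new column, and the saved PipelineModel reflects them.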