apache-spark pyspark apache-spark-mllib

StringIndexer with category levels passed as a list


StringIndexer infers the indices from the distinct values present in the data. This is a problem when the data does not contain every possible value. The toy example below considers three t-shirt sizes (Small, Medium, and Large), but only two of them (Small and Large) appear in the data. I would like the StringIndexer to still account for all three sizes. Is there a way to create a column holding the index of each string in a supplied list? Ideally this would be a Transformer, so that it can be reused in a pipeline.

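from pyspark.sql import Row
from pyspark.ml.feature import StringIndexer

# `spark` is an existing SparkSession (e.g. the one the pyspark shell provides)
df = spark.createDataFrame([Row(id='0', size='Small'),
                            Row(id='1', size='Small'),
                            Row(id='2', size='Large')])
(
    # Fitting learns the labels from df only, ordered by frequency (most frequent = 0.0)
    StringIndexer(inputCol="size", outputCol="size_idx")
    .fit(df)
    .transform(df)
    .show()
)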
+---+-----+--------+
| id| size|size_idx|
+---+-----+--------+
|  0|Small|     0.0|
|  1|Small|     0.0|
|  2|Large|     1.0|
+---+-----+--------+

Desired output

+---+-----+--------+
| id| size|size_idx|
+---+-----+--------+
|  0|Small|     0.0|
|  1|Small|     0.0|
|  2|Large|     2.0|
+---+-----+--------+

Solution

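  • You can build a StringIndexerModel directly from a list of labels with StringIndexerModel.from_labels instead of fitting a StringIndexer to the data. Each label's index is its position in the list, so levels that are missing from the data (here, Medium) still reserve their index.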

    from pyspark.sql import Row
    from pyspark.ml.feature import StringIndexerModel
    
    df = spark.createDataFrame([Row(id='0', size='Small'),
                                Row(id='1', size='Small'),
                                Row(id='2', size='Large')])
    
    # Index = position in the list: Small=0.0, Medium=1.0, Large=2.0
    si = StringIndexerModel.from_labels(['Small', 'Medium', 'Large'],
                                        inputCol="size",
                                        outputCol="size_idx")
    
    si.transform(df).show()
    +---+-----+--------+
    | id| size|size_idx|
    +---+-----+--------+
    |  0|Small|     0.0|
    |  1|Small|     0.0|
    |  2|Large|     2.0|
    +---+-----+--------+