
One-hot encoder in PySpark for multiple columns, each column having a different number of categorical labels


I'm new to PySpark, and I need to display all the unique labels that are present across several categorical columns.

I have a PySpark dataframe with the following columns:

ID   cat_col1   cat_col2
1    A          C
2    B          A

I want the final output to look like:

ID   cat_col1_A   cat_col1_B   cat_col1_C   cat_col2_A   cat_col2_B   cat_col2_C
1    1            0            0            0            0            1
2    0            1            0            1            0            0
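For reference, a minimal snippet to create the sample dataframe (temp_df, as used in the code below):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# sample input from the tables above
temp_df = spark.createDataFrame(
    [(1, 'A', 'C'), (2, 'B', 'A')],
    ['ID', 'cat_col1', 'cat_col2']
)

Here's what I've tried: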
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder

categorical_columns = ['cat_col1', 'cat_col2']

# index the string values of each categorical column
indexers = [
    StringIndexer(inputCol=c, outputCol="{0}_indexed".format(c)).setHandleInvalid("keep")
    for c in categorical_columns
]

# one-hot encode the indexed values of each column
encoders = [
    OneHotEncoder(dropLast=False,
                  inputCol=indexer.getOutputCol(),
                  outputCol="{0}_encoded".format(indexer.getOutputCol()))
    for indexer in indexers
]

pipeline = Pipeline(stages=indexers + encoders)
result = pipeline.fit(temp_df).transform(temp_df)
result.display()  # Databricks; use result.show() elsewhere

The output is something like this:

ID Cat_col_1_A_indexed Cat_col_1_B_indexed Cat_col_1_unknown_indexed Cat_col_2_A_indexed Cat_col_2_C_indexed Cat_col_2_unknown_indexed

Only the labels that occur in each individual column get encoded for that column. I want every categorical column to be encoded against the unique labels present across all categorical columns.


Solution

  • I could think of a couple of ways, and both of them require unpivoting the data (the setup common to both snippets is shown right after this list).

    • create distinct column names within the dataframe
    • create column names using the distinct values from all categories (requires a collect)
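
    Both snippets assume this setup; the imports and data_sdf are a minimal sketch matching the question's sample:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as func
    from pyspark.sql import Window as wd

    spark = SparkSession.builder.getOrCreate()

    # sample data from the question
    data_sdf = spark.createDataFrame(
        [(1, 'A', 'C'), (2, 'B', 'A')],
        ['id', 'cat_col1', 'cat_col2']
    )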

    within the dataframe

    # 1. pack each categorical column into a (c_name, c_val) struct and unpivot with inline()
    # 2. gather the distinct values across ALL columns into an array with a global window
    # 3. explode that array so each row is checked against every possible value
    # 4. pivot on "<column>_<value>" and flag (1/0) whether the row's value matches
    data_sdf. \
        withColumn('attr',
                   func.array(*[func.struct(func.lit(c).alias('c_name'),
                                            func.col(c).alias('c_val')
                                            ) for c in data_sdf.drop('id').columns]
                              )
                   ). \
        selectExpr('id', 'inline(attr)'). \
        withColumn('all_val', func.collect_set('c_val').over(wd.partitionBy(func.lit(1)))). \
        select('*', func.explode('all_val').alias('all_val_exp')). \
        withColumn('pivot_col', func.concat_ws('_', 'c_name', 'all_val_exp')). \
        groupBy('id'). \
        pivot('pivot_col'). \
        agg(func.max(func.col('c_val') == func.col('all_val_exp')).cast('int')). \
        show()
    
    # +---+----------+----------+----------+----------+----------+----------+
    # | id|cat_col1_A|cat_col1_B|cat_col1_C|cat_col2_A|cat_col2_B|cat_col2_C|
    # +---+----------+----------+----------+----------+----------+----------+
    # |  1|         1|         0|         0|         0|         0|         1|
    # |  2|         0|         1|         0|         1|         0|         0|
    # +---+----------+----------+----------+----------+----------+----------+
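
    For intuition, the inline(attr) step is the unpivot; running just the first half of the chain shows the intermediate long format it produces (row order may vary):

    data_sdf. \
        withColumn('attr',
                   func.array(*[func.struct(func.lit(c).alias('c_name'),
                                            func.col(c).alias('c_val')
                                            ) for c in data_sdf.drop('id').columns]
                              )
                   ). \
        selectExpr('id', 'inline(attr)'). \
        show()

    # +---+--------+-----+
    # | id|  c_name|c_val|
    # +---+--------+-----+
    # |  1|cat_col1|    A|
    # |  1|cat_col2|    C|
    # |  2|cat_col1|    B|
    # |  2|cat_col2|    A|
    # +---+--------+-----+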
    

    using a list to create the required column names

    # distinct values across all categorical columns, collected to the driver
    cats = data_sdf.select(func.array_distinct(func.flatten(func.collect_list(func.array('cat_col1', 'cat_col2'))))).collect()[0][0]
    
    # cross the column names with every value to build the full list of pivot columns
    fnl_cat_cols = sorted([x + '_' + y for x in data_sdf.drop('id').columns for y in cats])
    
    # ['cat_col1_A',
    #  'cat_col1_B',
    #  'cat_col1_C',
    #  'cat_col2_A',
    #  'cat_col2_B',
    #  'cat_col2_C']
    
    # unpivot as before, build "<column>_<value>" labels directly from each row,
    # then pivot with the explicit list of values; combinations that never occur
    # come back as null and are filled with 0
    data_sdf. \
        withColumn('attr',
                   func.array(*[func.struct(func.lit(c).alias('c_name'),
                                            func.col(c).alias('c_val')
                                            ) for c in data_sdf.drop('id').columns]
                              )
                   ). \
        selectExpr('id', 'inline(attr)'). \
        withColumn('pivot_col', func.concat_ws('_', 'c_name', 'c_val')). \
        groupBy('id'). \
        pivot('pivot_col', values=fnl_cat_cols). \
        agg(func.lit(1)). \
        fillna(0, subset=fnl_cat_cols). \
        show()
    
    # +---+----------+----------+----------+----------+----------+----------+
    # | id|cat_col1_A|cat_col1_B|cat_col1_C|cat_col2_A|cat_col2_B|cat_col2_C|
    # +---+----------+----------+----------+----------+----------+----------+
    # |  1|         1|         0|         0|         0|         0|         1|
    # |  2|         0|         1|         0|         1|         0|         0|
    # +---+----------+----------+----------+----------+----------+----------+
    

    The first approach is memory intensive given the two explodes (the inline() plus the explode of the collected set), so the second approach is the lighter of the two. Passing the explicit values list to pivot() also spares Spark the extra job it would otherwise run to discover the distinct pivot values.
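
    If you'd rather stay with the ML pipeline from the question, here is one more option as a sketch (assumes Spark 3.x for OneHotEncoder's inputCol, and Spark 2.4+ for from_labels): build every StringIndexerModel from the same union of labels, so all columns share one label-to-index mapping and the encoders emit same-sized vectors. Note this yields vector columns rather than the flat 0/1 columns above.

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexerModel, OneHotEncoder

    cat_cols = ['cat_col1', 'cat_col2']

    # union of labels across all categorical columns, sorted for stable indices
    labels = sorted(set(
        row[0]
        for c in cat_cols
        for row in data_sdf.select(c).distinct().collect()
    ))

    # every column gets the SAME label -> index mapping
    indexers = [
        StringIndexerModel.from_labels(labels, inputCol=c,
                                       outputCol=c + '_indexed',
                                       handleInvalid='keep')
        for c in cat_cols
    ]
    encoders = [
        OneHotEncoder(dropLast=False, inputCol=c + '_indexed',
                      outputCol=c + '_encoded')
        for c in cat_cols
    ]

    encoded_sdf = Pipeline(stages=indexers + encoders).fit(data_sdf).transform(data_sdf)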