Tags: python, encoding, pyspark, feature-extraction

How to do mean(target) encoding in pyspark


I need to apply mean (target) encoding to all categorical columns in my dataset. To simplify the problem, let's say the dataset has two columns: the first is the label column and the second is a categorical column.

For example:

label | cate1   
  0   |  abc    
  1   |  abc    
  0   |  def    
  0   |  def    
  1   |  ghi

According to the mean encoding strategy (https://towardsdatascience.com/why-you-should-try-mean-encoding-17057262cd0), each category is replaced by the mean of the label over the rows in that category, so the output should look like this:

label | cate1    
  0   |  0.5   
  1   |  0.5    
  0   |  0.0    
  0   |  0.0    
  1   |  1.0
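
To make the arithmetic behind that table explicit, here is a minimal plain-Python sketch of the same computation on the five example rows. It is only an illustration of the encoding rule, not a Spark solution; the variable names are made up for this example.

```python
from collections import defaultdict

# Rows from the example table: (label, cate1)
rows = [(0, "abc"), (1, "abc"), (0, "def"), (0, "def"), (1, "ghi")]

# Accumulate the label sum and the row count for each category.
sums = defaultdict(float)
counts = defaultdict(int)
for label, cate in rows:
    sums[cate] += label
    counts[cate] += 1

# Mean of the label within each category.
category_means = {cate: sums[cate] / counts[cate] for cate in sums}

# Replace each category value with its mean, keeping the label as is.
encoded = [(label, category_means[cate]) for label, cate in rows]
print(encoded)
```

Here "abc" has labels 0 and 1 (mean 0.5), "def" has 0 and 0 (mean 0.0), and "ghi" has a single 1 (mean 1.0).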

I tried to solve this with Koalas, but failed. This is what I tried:

for col_name in convert_cols:
    cat_mean_dict = dict()
    # map each category to its row count
    cur_col_cate_count_ = ks_df[col_name].value_counts().to_dict()
    print(cur_col_cate_count_)

    # for each category, count the positive labels and compute the mean
    for key in cur_col_cate_count_:
        current_col_positive_count = ks_df.loc[(ks_df['0'] == 1) & (ks_df[col_name] == key)].shape[0]
        key_mean = current_col_positive_count / cur_col_cate_count_[key]
        cat_mean_dict[key] = key_mean

    # replace each cell with its category mean (this cell-level assignment is the step that fails)
    for i in range(ks_df.shape[0]):
        cate_origin_hash = ks_df.at[i, col_name]
        if cate_origin_hash in cat_mean_dict:
            ks_df.at[i, col_name] = cat_mean_dict[cate_origin_hash]
        else:
            ks_df.at[i, col_name] = -1

But Koalas doesn't allow cell-level updates, so I can't modify a value with ks_df.at[i, col_name] = new_value.

So I'm hoping there is a PySpark solution to this problem.


Solution

  • Please find below a pyspark solution:

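    from pyspark.sql import Row, SparkSession
    import pyspark.sql.functions as F

    spark = SparkSession.builder.getOrCreate()

    # spark inputs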
    spark_data = [Row(label=0, cate1='abc'),
                  Row(label=1, cate1='abc'),
                  Row(label=0, cate1='def'),
                  Row(label=0, cate1='def'),
                  Row(label=1, cate1='ghi')]
    
    df = spark.createDataFrame(spark_data)
    
    df.show()
    >>>
    +-----+-----+
    |cate1|label|
    +-----+-----+
    |  abc|    0|
    |  abc|    1|
    |  def|    0|
    |  def|    0|
    |  ghi|    1|
    +-----+-----+
    
    
    # function
    def target_mean_encoding(df, col, target):
        """
        :param df: pyspark.sql.dataframe
            dataframe to apply target mean encoding
        :param col: str list
            list of columns to apply target encoding
        :param target: str
            target column
        :return:
            dataframe with target encoded columns
        """
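        target_encoded_columns_list = []
        for c in col:
            # mean of the target for each distinct value of column c
            means = df.groupby(F.col(c)).agg(F.mean(target).alias(f"{c}_mean_encoding"))
            # collect the category -> mean table to the driver
            dict_ = means.toPandas().to_dict()
            # one when() per category; each is null unless the row's value matches
            target_encoded_columns = [F.when(F.col(c) == v, encoder)
                                      for v, encoder in zip(dict_[c].values(),
                                                            dict_[f"{c}_mean_encoding"].values())]
            # coalesce picks the single non-null when() result for each row
            target_encoded_columns_list.append(F.coalesce(*target_encoded_columns).alias(f"{c}_mean_encoding"))
        return df.select(target, *target_encoded_columns_list)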
    
    
    # function apply on spark inputs
    df_target_encoded = target_mean_encoding(df, col=['cate1'], target='label')
    
    df_target_encoded.show()
    >>> 
    +-----+-------------------+
    |label|cate1_mean_encoding|
    +-----+-------------------+
    |    0|                0.5|
    |    1|                0.5|
    |    0|                0.0|
    |    0|                0.0|
    |    1|                1.0|
    +-----+-------------------+
    
    
    # if you want to keep the same column name after target mean encoding
    # (DataFrames are immutable, so the renamed result must be reassigned)
    df_target_encoded = df_target_encoded.withColumnRenamed('cate1_mean_encoding', 'cate1')
    
    df_target_encoded.show()
    >>>
    +-----+-----+
    |label|cate1|
    +-----+-----+
    |    0|  0.5|
    |    1|  0.5|
    |    0|  0.0|
    |    0|  0.0|
    |    1|  1.0|
    +-----+-----+