Tags: python, pandas, numpy

How to convert a column of lists into one-hot encoded columns?


Assume there is a DataFrame such as the following:

import pandas as pd 
import numpy as np 

df = pd.DataFrame({'id':range(1,4), 
                   'items':[['A', 'B'], ['A', 'B', 'C'], ['A', 'C']]})
df
        id  items
        1   [A, B]
        2   [A, B, C]
        3   [A, C]

Is there an efficient way to convert the above DataFrame into the following (one-hot encoded columns)? Many thanks in advance!

   id   items       A   B   C
    1   [A, B]      1   1   0
    2   [A, B, C]   1   1   1
    3   [A, C]      1   0   1

Solution

    SOLUTION 1

    One possible solution, with the following steps:

    • First, explode transforms each element of a list-like into its own row, replicating the index values (and therefore the id values).

    • Then, to_numpy converts the exploded dataframe to a numpy array of shape (n_rows, 2), and .T transposes it to shape (2, n_rows), so that unpacking it with * yields two 1-D arrays: the id values and the item values.

    • crosstab computes a simple cross-tabulation of these two arrays: each cell counts how many times an item occurs for a given id, which for this data is either 1 or 0.

    • Because the arrays passed to crosstab carry no names, the resulting index is named row_0; reset_index(names='id') turns that index into a column named id (the names argument requires pandas 1.5 or later).

    • Finally, the original dataframe df is merged with this table on their only common column, id, using merge.

    df.merge(
        pd.crosstab(*df.explode('items').to_numpy().T)
        .reset_index(names='id'))
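The same pipeline can be broken into named intermediate steps so each result can be inspected. This is a minimal sketch of a slight variant, not the answer's exact code: it passes the exploded id and items columns to crosstab as named Series instead of unpacking the transposed array, so the index is already called id and keeps its integer dtype. The ignore_index=True argument to explode is also an addition of this sketch.

```python
import pandas as pd

df = pd.DataFrame({'id': range(1, 4),
                   'items': [['A', 'B'], ['A', 'B', 'C'], ['A', 'C']]})

# Step 1: one row per (id, item) pair; ignore_index=True gives a fresh
# 0..n-1 index so crosstab's alignment of the two Series is unambiguous
exploded = df.explode('items', ignore_index=True)

# Step 2: count how often each item occurs for each id
# (named Series, so the index is called 'id' and the columns 'items')
counts = pd.crosstab(exploded['id'], exploded['items'])

# Step 3: move the id index into an ordinary column
counts = counts.reset_index()

# Step 4: merge on the only shared column, 'id'
result = df.merge(counts)
print(result)
```

Naming the inputs avoids the names='id' step, at the cost of a slightly longer expression than the one-liner above.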
    

    SOLUTION 2

    Another possible solution, with the following steps:

    • First, explode transforms each element of a list-like into its own row, replicating the index values.

    • Then, pivot_table reshapes the data so that each unique value in the items column becomes a column; with values='id' and aggfunc=len it counts the rows for every (id, item) pair. fill_value=0 fills the combinations that never occur with zeros.

    • rename_axis(None, axis=1) removes the name (items) that pivot_table leaves on the column axis.

    • Next, reset_index turns the id index back into a regular column.

    • Finally, the original dataframe df is merged with this transformed dataframe on the id column using merge.

    df.merge(
        df.explode('items')
        .pivot_table(index='id', columns='items', values='id', aggfunc=len, 
                     fill_value=0)
        .rename_axis(None, axis=1).reset_index())
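To see what rename_axis removes, the pivot step can be run on its own. This sketch swaps values='id', aggfunc=len for aggfunc='size', which counts rows per (id, item) pair without naming a values column; that substitution is a choice made for this illustration, not part of the answer, and the variable names pivoted and wide are hypothetical.

```python
import pandas as pd

df = pd.DataFrame({'id': range(1, 4),
                   'items': [['A', 'B'], ['A', 'B', 'C'], ['A', 'C']]})

# Count rows for every (id, item) pair; pairs that never occur become 0
pivoted = df.explode('items').pivot_table(index='id', columns='items',
                                          aggfunc='size', fill_value=0)

# The column axis still carries the name 'items' left over from the pivot
print(pivoted.columns.name)

# Drop that axis name, then turn the 'id' index back into a column
wide = pivoted.rename_axis(None, axis=1).reset_index()

# Attach the indicator columns to the original rows via the shared 'id'
result = df.merge(wide)
print(result)
```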
    

    Output:

       id      items  A  B  C
    0   1     [A, B]  1  1  0
    1   2  [A, B, C]  1  1  1
    2   3     [A, C]  1  0  1
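A shorter variant, offered as an alternative rather than as part of either solution above: join each list into a '|'-separated string and let Series.str.get_dummies split it back into indicator columns. This sketch assumes every item is a string that never contains '|' (the default separator of str.get_dummies). Because the dummy columns share df's index, join attaches them without needing the id column or a merge.

```python
import pandas as pd

df = pd.DataFrame({'id': range(1, 4),
                   'items': [['A', 'B'], ['A', 'B', 'C'], ['A', 'C']]})

# ['A', 'B'] -> 'A|B', then one 0/1 column per distinct item
dummies = df['items'].str.join('|').str.get_dummies()

# Same index as df, so join aligns row by row
result = df.join(dummies)
print(result)
```

Unlike crosstab and pivot_table, which count occurrences, str.get_dummies only marks presence, so an item repeated inside one list still yields 1.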