Pandas Multi Index column names


How do I create a simple DataFrame with MultiIndex columns?

import itertools

import numpy as np
import pandas as pd

quality = ['strong', 'weak']
types = ['major', 'minor', 'unknown']

a = [quality, types]
all_columns = list(itertools.product(*a))

list_ids = ['a1','b2','c3']
df = pd.DataFrame(list_ids, columns = ['ids'])

for col in all_columns:
    df[col] = np.nan

Now I try to convert the columns to a MultiIndex, but it does not seem to change anything:

df.columns = pd.MultiIndex.from_frame(pd.DataFrame(df.columns))
display(df['strong','major'])

KeyError: 'Key length (2) exceeds index depth (1)'
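The error comes from `pd.DataFrame(df.columns)`: it builds a frame with a single column that holds every label (`'ids'` plus the six tuples), so `MultiIndex.from_frame` can only produce a MultiIndex with one level. A minimal sketch of one way to keep the question's approach, assuming an empty second level (`''`) is acceptable for `ids`, is to pad every non-tuple label into a 2-tuple and rebuild the columns with `MultiIndex.from_tuples`:

```python
import itertools

import numpy as np
import pandas as pd

quality = ['strong', 'weak']
types = ['major', 'minor', 'unknown']

df = pd.DataFrame(['a1', 'b2', 'c3'], columns=['ids'])
for col in itertools.product(quality, types):
    df[col] = np.nan  # tuple labels on an ordinary single-level Index

# One column of labels, so from_frame would build a single-level MultiIndex.
label_frame = pd.DataFrame(df.columns)
print(label_frame.shape)  # (7, 1)

# Pad 'ids' to ('ids', '') so every label has two parts, then rebuild.
df.columns = pd.MultiIndex.from_tuples(
    [c if isinstance(c, tuple) else (c, '') for c in df.columns]
)
print(df.columns.nlevels)     # 2
print(df['strong', 'major'])  # Series with three NaN values
print(df['strong'])           # the three 'strong' columns
```

With two levels everywhere, a partial key such as `df['strong']` selects all three `strong` columns at once.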


Solution

  • Try pd.concat together with pd.MultiIndex.from_product instead:

    result_df = pd.concat(
        (df, pd.DataFrame(columns=pd.MultiIndex.from_product([quality, types]))),
        axis=1
    )
    

    result_df:

      ids (strong, major) (strong, minor) (strong, unknown) (weak, major) (weak, minor) (weak, unknown)
    0  a1             NaN             NaN               NaN           NaN           NaN             NaN
    1  b2             NaN             NaN               NaN           NaN           NaN             NaN
    2  c3             NaN             NaN               NaN           NaN           NaN             NaN
    

    result_df['strong', 'major']:

    0    NaN
    1    NaN
    2    NaN
    Name: (strong, major), dtype: object
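The single header row in the output above suggests that `result_df.columns` is an ordinary `Index` whose labels happen to be tuples, not a two-level `MultiIndex`: `'ids'` has only one level, so the two column sets appear to be joined into a flat object index. The full tuple key still selects the column. A quick way to check this, using the same `list_ids`, `quality` and `types`:

```python
import pandas as pd

list_ids = ['a1', 'b2', 'c3']
quality = ['strong', 'weak']
types = ['major', 'minor', 'unknown']

df = pd.DataFrame(list_ids, columns=['ids'])
result_df = pd.concat(
    (df, pd.DataFrame(columns=pd.MultiIndex.from_product([quality, types]))),
    axis=1
)

# Inspect which kind of column index the concatenation produced.
print(type(result_df.columns).__name__, result_df.columns.nlevels)
# The tuple labels are present, so the full (quality, type) key works.
print(('strong', 'major') in result_df.columns)
print(result_df['strong', 'major'])
```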
    

    import pandas as pd
    
    list_ids = ['a1', 'b2', 'c3']
    df = pd.DataFrame(list_ids, columns=['ids'])
    
    quality = ['strong', 'weak']
    types = ['major', 'minor', 'unknown']
    
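    # Append the (quality, type) columns; the empty frame has no rows,
    # so pandas aligns on the row index and fills them with NaN.
    result_df = pd.concat(
        (df, pd.DataFrame(columns=pd.MultiIndex.from_product([quality, types]))),
        axis=1
    )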
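If `ids` can serve as the row labels rather than a data column, an alternative (swapped in here, not the approach above) is to pass the ids as the index and build the columns directly with `MultiIndex.from_product`. Every column label then has exactly two levels, so pandas creates a genuine `MultiIndex`. A minimal sketch, assuming float NaN columns are wanted:

```python
import pandas as pd

list_ids = ['a1', 'b2', 'c3']
quality = ['strong', 'weak']
types = ['major', 'minor', 'unknown']

# ids become the row index; columns are the full (quality, type) product.
result_df = pd.DataFrame(
    index=pd.Index(list_ids, name='ids'),
    columns=pd.MultiIndex.from_product([quality, types]),
    dtype=float,  # all-NaN float64 columns instead of object
)

print(result_df['strong', 'major'])  # one column as a Series
print(result_df['strong'])           # all three 'strong' columns
```

If `ids` is needed as a column again, `result_df.reset_index()` moves it back, labelled `('ids', '')` so that it fits the two-level column index.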