
Pandas reshape dataframe by adding a column level based on the value of another column


I have a pandas DataFrame and I would like to add a column level that splits specific columns (metric_a, metric_b, metric_c) into subcolumns, one per value of another column (param).


Current data format:

    participant param   metric_a    metric_b    metric_c
0   alice       a       0,700       0,912       0,341
1   alice       b       0,736       0,230       0,370
2   bob         a       0,886       0,364       0,995
3   bob         b       0,510       0,704       0,990
4   charlie     a       0,173       0,462       0,709
5   charlie     b       0,085       0,950       0,807
6   david       a       0,676       0,653       0,189
7   david       b       0,823       0,524       0,430

Wanted data format:

    participant metric_a        metric_b        metric_c
                a       b       a       b       a       b
0   alice       0,700   0,736   0,912   0,230   0,341   0,370
1   bob         0,886   0,510   0,364   0,704   0,995   0,990
2   charlie     0,173   0,085   0,462   0,950   0,709   0,807
3   david       0,676   0,823   0,653   0,524   0,189   0,430

I have tried

df.set_index(['participant', 'param']).unstack(['param'])

which gives a close result, but not quite what I want: I would like to keep a single-level index, with participant as a regular column.

            metric_a        metric_b        metric_c
param       a       b       a       b       a       b
participant
alice       0,700   0,736   0,912   0,230   0,341   0,370
bob         0,886   0,510   0,364   0,704   0,995   0,990
charlie     0,173   0,085   0,462   0,950   0,709   0,807
david       0,676   0,823   0,653   0,524   0,189   0,430

I suspect that groupby() or pivot_table() could do the job, but I cannot figure out how.


Solution

  • IIUC, use DataFrame.set_index and unstack, then call reset_index with the col_level parameter so that participant is inserted into the top column level as a regular column:

    df.set_index(['participant', 'param']).unstack('param').reset_index(col_level=0)
    

    [out]

          participant metric_a        metric_b        metric_c       
    param                    a      b        a      b        a      b
    0           alice    0,700  0,736    0,912  0,230    0,341  0,370
    1             bob    0,886  0,510    0,364  0,704    0,995  0,990
    2         charlie    0,173  0,085    0,462  0,950    0,709  0,807
    3           david    0,676  0,823    0,653  0,524    0,189  0,430
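
For anyone who wants to reproduce this, here is a minimal self-contained sketch. It rebuilds the sample frame from the question, assuming the metrics are floats (the question displays them with decimal commas), and the variable names `wide` and `wide_pivot` are only illustrative. It also shows `DataFrame.pivot` as what should be an equivalent spelling of the `set_index` + `unstack` step.

```python
import pandas as pd

# Sample data rebuilt from the question. Assumption: the metrics are floats;
# the question shows them with decimal commas (e.g. 0,700).
df = pd.DataFrame({
    "participant": ["alice", "alice", "bob", "bob",
                    "charlie", "charlie", "david", "david"],
    "param": ["a", "b"] * 4,
    "metric_a": [0.700, 0.736, 0.886, 0.510, 0.173, 0.085, 0.676, 0.823],
    "metric_b": [0.912, 0.230, 0.364, 0.704, 0.462, 0.950, 0.653, 0.524],
    "metric_c": [0.341, 0.370, 0.995, 0.990, 0.709, 0.807, 0.189, 0.430],
})

# Move param into a second column level, then turn participant back into a
# regular column placed in the top column level (col_level=0).
wide = (
    df.set_index(["participant", "param"])
      .unstack("param")
      .reset_index(col_level=0)
)

# Optional: drop the leftover "param" label shown next to the column header.
wide.columns.names = [None, None]

# Alternative spelling of the same reshape using pivot.
wide_pivot = (
    df.pivot(index="participant", columns="param")
      .reset_index(col_level=0)
)

print(wide)

# Subcolumns are addressed with (metric, param) tuples.
print(wide[("metric_a", "b")])
```

The participant column ends up under the key `("participant", "")`, because `reset_index` fills the lower level with an empty string by default (`col_fill=''`).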