python, pandas, dataframe, group-by, aggregation

groupby with dictionary comprehension


I have a dataframe

df = pd.DataFrame({'id': [1, 2, 3, 1, 1],
                   'time_stamp_date': ['12', '12', '12', '14', '14'],
                   'sth': ['col1', 'col1', 'col2', 'col2', 'col3']})

   id time_stamp_date   sth
0   1              12  col1
1   2              12  col1
2   3              12  col2
3   1              14  col2
4   1              14  col3

and I would like to get the following dataframe. That is, for each column name in sth_list I would like to check whether it appears in sth for a given id and time_stamp_date.

   id time_stamp_date   sth  col1  col2  col3  col4
0   1              12  col1     1     0     0     0
1   2              12  col1     1     0     0     0
2   3              12  col2     0     1     0     0
3   1              14  col2     0     1     1     0
4   1              14  col3     0     1     1     0

I can do it like this

df_out = df.assign(**{col: df.groupby(['id', 'time_stamp_date']).sth.transform(
                                                lambda x: (x == col).any()).astype(int)
                      for col in sth_list})

but I would like to use only groupby, without transform followed by drop_duplicates(). The version below doesn't work, though, because all of the aggregation functions end up with the same name; the error I get is SpecificationError: Function names must be unique, found multiple named <lambda>.

df_out = df.groupby(['id', 'time_stamp_date'])[['sth']].agg({lambda x: (x == col).any().astype(int)
                                                             for col in sth_list})

Is it possible to make the above code work?
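
One route that should sidestep the duplicate-name error, assuming pandas 0.25 or newer (where named aggregation accepts repeated lambdas), is to give each output column its own keyword name so the anonymous functions no longer collide. A minimal sketch, reusing df from above:

sth_list = ['col1', 'col2', 'col3', 'col4']

# Named aggregation: each keyword names its own output column, so the
# duplicate "<lambda>" clash does not occur. `c=col` freezes the loop
# variable so each lambda checks whether its own value appears in the group.
flags = df.groupby(['id', 'time_stamp_date']).agg(
    **{col: ('sth', lambda s, c=col: int((s == c).any())) for col in sth_list}
)

Note that this yields one row per (id, time_stamp_date) group rather than one row per original record, so it is not a drop-in replacement for the transform version above.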


Solution

  • You just need to think in terms of the Index for this one: you already know where the values for your dummies exist, so you can rely on index manipulations and fill in the values that don't exist.

    import numpy as np
    import pandas as pd
    
    df = pd.DataFrame({
        'id':[1,2,3,1, 1],
        'time_stamp_date':['12','12', '12', '14', '14'],
        'sth':['col1','col1', 'col2','col2', 'col3']
    })
    sth_list = ['col1', 'col2', 'col3', 'col4']
    
    out = (
        pd.Series(data=1, index=pd.MultiIndex.from_frame(df))
        .unstack('sth', fill_value=0)
        .reindex(
            # index=… here is only if you want row order from original data
            index=pd.MultiIndex.from_frame(df[['id', 'time_stamp_date']].drop_duplicates()),
            columns=sth_list, fill_value=0
        )
    )
    print(out)
    sth                 col1  col2  col3  col4
    id time_stamp_date                        
    1  12                  1     0     0     0
    2  12                  1     0     0     0
    3  12                  0     1     0     0
    1  14                  0     1     1     0
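
    If you want the row-level frame shown in the question (one indicator row per original record, with sth kept), the group-level table can be joined back onto df; a small sketch reusing out from above:

    # `on=` matches the id/time_stamp_date columns of df against the
    # MultiIndex of `out`, broadcasting the per-group indicators to every row.
    df_out = df.join(out, on=['id', 'time_stamp_date'])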