
pandas grouped list aggregation using transform fails with KeyError


How can I apply a list-producing function to a grouped pandas DataFrame using transform rather than apply, so that the result is not aggregated but keeps one row per original row?

For me the following fails with: KeyError: "None of [Index(['v1', 'v2'], dtype='object')] are in the [index]"

import pandas as pd
df = pd.DataFrame({'key':[1,1,1,2,3,2], 'v1': [1,4,6,7,4,9], 'v2':[0.3, 0.6, 0.4, .1, .2, .8]})
display(df)

def list_function(x):
    #display(x)
    all_values = x[['v1','v2']].drop_duplicates()
    #display(all_values)
    #result = all_values.to_json()
    result = all_values.values
    return result


display(df.groupby(['key']).apply(list_function))
df['list_result'] = df.groupby(['key']).transform(list_function)
df

NOTICE: I know that a join would be possible with the aggregated data, but in this particular case I would prefer not having to do the JOIN afterwards.


Solution

  • This is not possible with transform. GroupBy.transform (like GroupBy.agg) works on each column separately, so the function cannot select multiple columns by name the way you need.

    It is only possible with GroupBy.apply, which passes the whole group to the function.

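The KeyError can be reproduced without groupby, which may make the cause easier to see. This is a minimal sketch, assuming the function receives one column of a group as a Series; the names `single_column` and `whole_group` are illustrative and not part of the original code.

```python
import pandas as pd

df = pd.DataFrame({'key': [1, 1, 2], 'v1': [1, 4, 7], 'v2': [0.3, 0.6, 0.1]})

# Stand-in for a per-column call: one column of the key == 1 group, as a Series.
single_column = df.loc[df['key'] == 1, 'v1']

try:
    # On a Series, a list of labels is looked up in the row index (0, 1),
    # so the column names 'v1' and 'v2' are not found.
    single_column[['v1', 'v2']]
    raised = False
except KeyError:
    raised = True

# Stand-in for an apply-style call: the whole group as a DataFrame.
whole_group = df.loc[df['key'] == 1]
selected = whole_group[['v1', 'v2']]  # column selection works here

print(raised)
print(selected)
```

Selecting `['v1', 'v2']` only makes sense on a DataFrame, which is why the function works under apply.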
    So instead of transform, use Series.map when matching on a single key column; to match on multiple columns, use DataFrame.join:

    df['list_result'] = df['key'].map(df.groupby(['key']).apply(list_function))
    print(df)
    
       key  v1   v2                           list_result
    0    1   1  0.3  [[1.0, 0.3], [4.0, 0.6], [6.0, 0.4]]
    1    1   4  0.6  [[1.0, 0.3], [4.0, 0.6], [6.0, 0.4]]
    2    1   6  0.4  [[1.0, 0.3], [4.0, 0.6], [6.0, 0.4]]
    3    2   7  0.1              [[7.0, 0.1], [9.0, 0.8]]
    4    3   4  0.2                          [[4.0, 0.2]]
    5    2   9  0.8              [[7.0, 0.1], [9.0, 0.8]]
    

    # add a second column to demonstrate matching on two key columns
    df['new'] = 1
    
    s = df.groupby(['key', 'new']).apply(list_function)
    df = df.join(s.rename('list_result'), on=['key','new'])
    print(df)
       key  v1   v2  new                           list_result
    0    1   1  0.3    1  [[1.0, 0.3], [4.0, 0.6], [6.0, 0.4]]
    1    1   4  0.6    1  [[1.0, 0.3], [4.0, 0.6], [6.0, 0.4]]
    2    1   6  0.4    1  [[1.0, 0.3], [4.0, 0.6], [6.0, 0.4]]
    3    2   7  0.1    1              [[7.0, 0.1], [9.0, 0.8]]
    4    3   4  0.2    1                          [[4.0, 0.2]]
    5    2   9  0.8    1              [[7.0, 0.1], [9.0, 0.8]]
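
For reference, the single-key approach can be written as a self-contained sketch. The names `to_pairs` and `per_key` are illustrative. Two choices here are assumptions rather than part of the answer: `v1` and `v2` are selected before `apply`, so the function does not depend on whether the grouping column is passed along, and `.tolist()` stores plain Python lists instead of NumPy arrays.

```python
import pandas as pd

df = pd.DataFrame({
    'key': [1, 1, 1, 2, 3, 2],
    'v1': [1, 4, 6, 7, 4, 9],
    'v2': [0.3, 0.6, 0.4, 0.1, 0.2, 0.8],
})

def to_pairs(group):
    # Receives the whole sub-DataFrame for one key (apply semantics).
    return group[['v1', 'v2']].drop_duplicates().values.tolist()

# One list of [v1, v2] pairs per key; the result is a Series indexed by key.
per_key = df.groupby('key')[['v1', 'v2']].apply(to_pairs)

# Broadcast the per-key result back onto the original rows without a join.
df['list_result'] = df['key'].map(per_key)
print(df)
```

Because `map` looks up each row's `key` in the index of `per_key`, the original row order and index are preserved, which is the effect the question wanted from transform.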