
get the name of the group inside pandas groupby transform

Here is what I am trying to do. I have the following DataFrame in pandas:

import numpy as np
import pandas as pd

n_cols = 3
n_samples = 4
df = pd.DataFrame(np.arange(n_samples * n_cols).reshape(n_samples, n_cols), columns=list('ABC'))


   A   B   C
0  0   1   2
1  3   4   5
2  6   7   8
3  9  10  11

I have a category to which each sample (row) belongs:

cat = pd.Series([1,1,2,2])

And I have a reference row related to each category:

df_ref = pd.DataFrame(np.zeros((2, n_cols)), index=[1,2], columns=list('ABC'))
df_ref.loc[1] = 10


      A     B     C
1  10.0  10.0  10.0
2   0.0   0.0   0.0

How do I do the following in a more elegant way (e.g., using groupby and transform):

result = df.copy()
for i in range(n_samples):
    result.iloc[i] = df.iloc[i] - df_ref.loc[cat[i]]


    A   B   C
0 -10  -9  -8
1  -7  -6  -5
2   6   7   8
3   9  10  11

I thought something like this should work:

df.groupby(cat).transform(lambda x: x - df_ref.loc[x.GROUP_NAME])

where x.GROUP_NAME stands for the name of the group that transform is currently operating on. The pandas documentation for transform says: "Each group is endowed the attribute ‘name’ in case you need to know which group you are working on." I tried accessing x.name, but that gives the name of a column, not the name of the group, so I don't understand what this part of the documentation is referring to.
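For comparison, here is a minimal sketch (not a tuned solution) that makes the group names explicit by iterating over the groupby object. Iteration yields (name, group) pairs, where name is the category label (1 or 2 here) and group is the sub-DataFrame of rows in that category. It reuses the df, cat and df_ref from above; the pieces and result names are illustrative.

```python
import numpy as np
import pandas as pd

# Same data as in the question.
n_cols, n_samples = 3, 4
df = pd.DataFrame(np.arange(n_samples * n_cols).reshape(n_samples, n_cols), columns=list('ABC'))
cat = pd.Series([1, 1, 2, 2])
df_ref = pd.DataFrame(np.zeros((2, n_cols)), index=[1, 2], columns=list('ABC'))
df_ref.loc[1] = 10

pieces = []
for name, group in df.groupby(cat):
    # `name` is the category label; `group` holds the rows of that category.
    # Subtracting a Series aligns it with the columns A, B, C and broadcasts over rows.
    pieces.append(group - df_ref.loc[name])

# Groups come out ordered by category, so restore the original row order.
result = pd.concat(pieces).sort_index()
print(result)
```

This is essentially what groupby.apply does for you, with the group label available as the loop variable instead of an attribute.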


  • There is no need for groupby: just reindex df_ref by cat and convert it to an array:

    df -= df_ref.reindex(cat).values

    Or, for a copy:

    out = df.sub(df_ref.reindex(cat).values)

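    The name you saw was a column label because transform also calls the function on each column of the group separately, so x is then a single column. Your approach does work with groupby.apply, which passes each group as a whole DataFrame whose name attribute is the group label: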

    out = df.groupby(cat, group_keys=False).apply(lambda x: x - df_ref.loc[x.name])


          A     B     C
    0 -10.0  -9.0  -8.0
    1  -7.0  -6.0  -5.0
    2   6.0   7.0   8.0
    3   9.0  10.0  11.0
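Why the conversion to an array matters in the reindex approach: df_ref.reindex(cat) produces one reference row per sample, but its index is the category labels (1, 1, 2, 2) rather than df's row labels (0 to 3). Subtracting it as a DataFrame would align rows by label instead of by position. The sketch below (variable names ref_rows, out_values and out_aligned are illustrative) shows the array version and, as an assumed equivalent alternative, relabelling the reference rows with df's index so that the labels match. It uses .to_numpy(), which pandas recommends over .values in recent versions; both give the same array here.

```python
import numpy as np
import pandas as pd

# Same data as in the question.
n_cols, n_samples = 3, 4
df = pd.DataFrame(np.arange(n_samples * n_cols).reshape(n_samples, n_cols), columns=list('ABC'))
cat = pd.Series([1, 1, 2, 2])
df_ref = pd.DataFrame(np.zeros((2, n_cols)), index=[1, 2], columns=list('ABC'))
df_ref.loc[1] = 10

# One reference row per sample; its index is the category labels [1, 1, 2, 2].
ref_rows = df_ref.reindex(cat)

# A plain NumPy array carries no labels, so rows are matched by position.
out_values = df.sub(ref_rows.to_numpy())

# Label-preserving alternative: give the reference rows df's own index first.
out_aligned = df.sub(ref_rows.set_axis(df.index))

print(out_values)
```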