Pandas groupby and aggregate: produce unique single values for some cells


All,

I have the following pd.DataFrame():

import pandas as pd

df = pd.DataFrame({'fruit': ['carrot','carrot','apple','apple', 'tomato'],
                   'taste': ['sweet','sweet','sweet','bitter','bitter'],
                   'quality': ['good','poor','rotten','good','good']})

Which looks like this:

    fruit   taste quality
0  carrot   sweet    good
1  carrot   sweet    poor
2   apple   sweet  rotten
3   apple  bitter    good
4  tomato  bitter    good

I apply groupby and agg like this:

df.groupby('fruit').agg(pd.Series.tolist)

producing:

                  taste         quality
fruit                                  
apple   [sweet, bitter]  [rotten, good]
carrot   [sweet, sweet]    [good, poor]
tomato         [bitter]          [good]

But what I want is:

                  taste         quality
fruit                                  
apple   [sweet, bitter]  [rotten, good]
carrot            sweet    [good, poor]
tomato           bitter            good

In words: I want to aggregate into a list only the entries that have multiple different values; when all values in a group are the same, I want the cell to contain just that single value. Is there a nice way of doing this, preferably without going through every cell of the df (mine is rather big), or is that the only way? Apologies if I'm unclear; I struggle to express this in words (hence the awkward title).

Thank you in advance.


Solution

  • Use a custom lambda function that removes duplicates with a set and converts groups with a single unique value to a scalar:

    f = lambda x: list(set(x)) if len(set(x)) > 1 else x.iat[0]
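    out = df.groupby('fruit').agg(f)
    print(out)
                      taste         quality
    fruit                                  
    apple   [sweet, bitter]  [rotten, good]
    carrot            sweet    [poor, good]
    tomato           bitter            good
    

    Set order is arbitrary and can change between runs, so the lists above may come out in a different order. If the order of first appearance matters, use dict.fromkeys, which preserves insertion order:

    f = lambda x: list(dict.fromkeys(x)) if len(set(x)) > 1 else x.iat[0]
    out = df.groupby('fruit').agg(f)
    print(out)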
                      taste         quality
    fruit                                  
    apple   [sweet, bitter]  [rotten, good]
    carrot            sweet    [good, poor]
    tomato           bitter            good
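A self-contained sketch of the order-preserving variant, swapping in pandas' `Series.unique()` (which returns the distinct values in order of first appearance) for `dict.fromkeys`. The helper name `collapse_unique` is hypothetical, and the result goes into a new variable `out` so `df` keeps its original columns. Like the lambdas above, it assumes `agg` accepts a function that returns a list for some groups and a scalar for others; it has not been checked against every pandas version.

```python
import pandas as pd

# Same example data as in the question
df = pd.DataFrame({'fruit': ['carrot', 'carrot', 'apple', 'apple', 'tomato'],
                   'taste': ['sweet', 'sweet', 'sweet', 'bitter', 'bitter'],
                   'quality': ['good', 'poor', 'rotten', 'good', 'good']})


def collapse_unique(x: pd.Series):
    """Hypothetical helper: return the group's single value, or a list of its
    distinct values in order of first appearance."""
    uniques = x.unique()  # distinct values, in order of first appearance
    if len(uniques) > 1:
        return list(uniques)  # several different values: keep them as a list
    return uniques[0]  # only one distinct value: return it as a scalar


# Write to a new variable so df keeps its original columns
out = df.groupby('fruit').agg(collapse_unique)
print(out)
```

This still calls a Python function once per group and column rather than once per cell. Counting with `len(uniques)` instead of `x.nunique()` keeps the count consistent with the returned values, since `nunique()` drops NaN by default while `unique()` keeps it.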