Tags: python, pandas, greatest-n-per-group

Use groupby to get the n smallest values per group, including duplicates


Suppose I have a pandas DataFrame like this:

    >>> import pandas as pd
    >>> df = pd.DataFrame({'id':[1,1,1,1,1,2,2,2,2,2,2,3,4],'value':[1,1,1,1,3,1,2,2,3,3,4,1,1]})
    >>> df
        id  value
    0    1      1
    1    1      1
    2    1      1
    3    1      1
    4    1      3
    5    2      1
    6    2      2
    7    2      2
    8    2      3
    9    2      3
    10   2      4
    11   3      1
    12   4      1

I want a new DataFrame that keeps, for each id, every row whose value is among the 2 (or, in general, n) smallest distinct values of that id, duplicates included, like this:

        id  value
    0    1      1
    1    1      1
    2    1      1
    3    1      1
    4    1      3
    5    2      1
    6    2      2
    7    2      2
    11   3      1
    12   4      1

I've tried using head() and nsmallest(), but I don't think either of them keeps all the duplicates. Is there a better way to do this?
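
For reference, here is a quick sketch of what those two methods return on this data. `top2_rows` and `ties` are illustrative names; `keep='all'` is the tie-handling option of `SeriesGroupBy.nsmallest`.

```python
import pandas as pd

df = pd.DataFrame({'id': [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4],
                   'value': [1, 1, 1, 1, 3, 1, 2, 2, 3, 3, 4, 1, 1]})

# head(2) counts rows, not distinct values: id 1 keeps only two of its four 1s
# and never reaches its second distinct value, 3
top2_rows = df.sort_values(['id', 'value']).groupby('id').head(2)
print(top2_rows)

# nsmallest(2, keep='all') keeps every row tied with the 2nd smallest value,
# so id 1 gets all four 1s, but its 3 is still dropped
ties = df.groupby('id')['value'].nsmallest(2, keep='all')
print(ties)
```

Neither gives the 10 expected rows: both cut off after n rows (plus ties), not after n distinct values.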

Edit: to be clear, I want more than 2 rows per group when a group has more than 2 duplicates.


Solution

  • Use DataFrame.drop_duplicates first, then take the 2 smallest distinct values per id, and finally use DataFrame.merge to bring back every matching row:

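    import pandas as pd

    df = pd.DataFrame({'id':[1,1,1,1,1,2,2,2,2,2,2,3,4],'value':[1,1,1,1,3,1,2,2,3,3,4,1,1]})

    # one row per (id, value) pair, sorted, then the 2 smallest distinct values per id
    df1 = df.drop_duplicates(['id','value']).sort_values(['id','value']).groupby('id').head(2)
    # inner merge on the common columns brings back every duplicate of a kept value
    out = df.merge(df1)
    print(out)
       id  value
    0   1      1
    1   1      1
    2   1      1
    3   1      1
    4   1      3
    5   2      1
    6   2      2
    7   2      2
    8   3      1
    9   4      1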
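
The merge idea generalizes directly to any n. Below is a minimal sketch wrapping it in a function; the name `n_smallest_with_ties` and its `key`/`col` parameters are hypothetical, chosen for illustration only.

```python
import pandas as pd


def n_smallest_with_ties(df, n, key='id', col='value'):
    """Hypothetical helper: return every row of df whose `col` is among
    the n smallest distinct values of its `key` group."""
    # one row per (key, col) pair, sorted so head(n) picks the smallest values
    distinct = (df.drop_duplicates([key, col])
                  .sort_values([key, col])
                  .groupby(key)
                  .head(n))
    # inner merge on both columns restores every duplicate of a kept value
    return df.merge(distinct[[key, col]], on=[key, col])


df = pd.DataFrame({'id': [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4],
                   'value': [1, 1, 1, 1, 3, 1, 2, 2, 3, 3, 4, 1, 1]})
print(n_smallest_with_ties(df, 2))
print(n_smallest_with_ties(df, 1))
```

Selecting only `[key, col]` from the deduplicated frame keeps the join on exactly those two columns, which matters if the real DataFrame has extra columns that would otherwise become part of the merge keys.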
    

    Alternatively, use a custom lambda function with GroupBy.transform and filter with boolean indexing:

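    import pandas as pd

    df = pd.DataFrame({'id':[1,1,1,1,1,2,2,2,2,2,2,3,4],'value':[1,1,1,1,3,1,2,2,3,3,4,1,1]})

    # True for rows whose value is among the 2 smallest distinct values of their id
    mask = df.groupby('id')['value'].transform(lambda x: x.isin(sorted(set(x))[:2]))
    df = df[mask]
    print(df)
        id  value
    0    1      1
    1    1      1
    2    1      1
    3    1      1
    4    1      3
    5    2      1
    6    2      2
    7    2      2
    11   3      1
    12   4      1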
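
Another option, not part of the answer above, is to rank values within each id using `GroupBy.rank(method='dense')`. Dense ranking gives tied values the same rank with no gaps, so keeping rows with rank <= n selects exactly the n smallest distinct values per id. A minimal sketch, with `n` as an illustrative variable:

```python
import pandas as pd

df = pd.DataFrame({'id': [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4],
                   'value': [1, 1, 1, 1, 3, 1, 2, 2, 3, 3, 4, 1, 1]})

n = 2
# method='dense': equal values share a rank and ranks have no gaps,
# so rank <= n means "among the n smallest distinct values of this id"
dense_rank = df.groupby('id')['value'].rank(method='dense')
out = df[dense_rank <= n]
print(out)
```

Like the boolean-indexing approach, this keeps the original index, and it avoids calling a Python lambda once per group.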