Tags: python, pandas, chaining, method-chaining

How to Select the First N Key-ordered Values of a Column within a Grouping Variable in a Pandas DataFrame


I have a dataset:

import pandas as pd

data = [
    ('A', 'X'),
    ('A', 'X'),
    ('A', 'Y'),
    ('A', 'Z'),
    ('B', 1),
    ('B', 1),
    ('B', 2),
    ('B', 2),
    ('B', 3),
    ('B', 3),
    ('C', 'L-7'),
    ('C', 'L-9'),
    ('C', 'L-9'),
    ('T', 2020),
    ('T', 2020),
    ('T', 2025)
]

df = pd.DataFrame(data, columns=['ID', 'SEQ'])
print(df)

I want to create a key from ID and SEQ so that, within each ID group, I can select the first 2 rows of each of the first 2 distinct SEQ values.

For instance, ID A has 3 distinct keys: "A X", "A Y" and "A Z". In the order of the dataset, the first two keys are "A X" and "A Y", so I must select the first two rows (if available) of each, which gives:

"A X", "A X", "A Y"

"A Z" is excluded because it is the third distinct key of A.

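To make "distinct keys in the order of the dataset" concrete, the sketch below lists each ID's distinct SEQ values in order of first appearance. It relies on SeriesGroupBy.unique(), which returns values in the order they are encountered; the name keys_per_id is only illustrative.

```python
import pandas as pd

# Sample data from the question
data = [('A', 'X'), ('A', 'X'), ('A', 'Y'), ('A', 'Z'),
        ('B', 1), ('B', 1), ('B', 2), ('B', 2), ('B', 3), ('B', 3),
        ('C', 'L-7'), ('C', 'L-9'), ('C', 'L-9'),
        ('T', 2020), ('T', 2020), ('T', 2025)]
df = pd.DataFrame(data, columns=['ID', 'SEQ'])

# Distinct SEQ values per ID, in order of first appearance.
# sort=False also keeps the IDs in the order they appear.
keys_per_id = df.groupby('ID', sort=False)['SEQ'].unique()
print(keys_per_id)
```

The "first two keys" of an ID are then simply the first two entries of its array, e.g. X and Y for A.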
I've tried using the groupby and head functions, but I couldn't find a way to achieve this specific result. What can I try next?

(df
.groupby(['ID','SEQ'])
.head(2)
)

This code returns the original dataset unchanged: no (ID, SEQ) pair occurs more than twice, so head(2) never drops a row. I also wonder whether this problem can be solved with method chaining, as it is my preferred style in Pandas.
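A quick way to see why groupby(['ID', 'SEQ']).head(2) keeps every row is to count the rows per (ID, SEQ) key. This sketch uses only groupby().size() and groupby().head() on the question's data.

```python
import pandas as pd

# Sample data from the question
data = [('A', 'X'), ('A', 'X'), ('A', 'Y'), ('A', 'Z'),
        ('B', 1), ('B', 1), ('B', 2), ('B', 2), ('B', 3), ('B', 3),
        ('C', 'L-7'), ('C', 'L-9'), ('C', 'L-9'),
        ('T', 2020), ('T', 2020), ('T', 2025)]
df = pd.DataFrame(data, columns=['ID', 'SEQ'])

# Number of rows for each (ID, SEQ) key
sizes = df.groupby(['ID', 'SEQ']).size()
print(sizes.max())  # 2: no key has more than two rows

# head(2) keeps up to two rows per key, so with this data nothing is dropped
out = df.groupby(['ID', 'SEQ']).head(2)
print(out.index.equals(df.index))  # True
```

In other words, head(2) only limits rows per key; it has no notion of "the first two keys per ID", which is the missing piece.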

The final correct output is:

[Image: expected output table]
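Because the expected output is given only as an image, the following plain-Python sketch applies the rule exactly as stated in the question (per ID, the first 2 distinct SEQ values in order of appearance, with at most 2 rows each) to list which rows should be kept. It is a loop-based cross-check rather than a pandas idiom; N_KEYS, N_ROWS and the other names are illustrative.

```python
import pandas as pd

# Sample data from the question
data = [('A', 'X'), ('A', 'X'), ('A', 'Y'), ('A', 'Z'),
        ('B', 1), ('B', 1), ('B', 2), ('B', 2), ('B', 3), ('B', 3),
        ('C', 'L-7'), ('C', 'L-9'), ('C', 'L-9'),
        ('T', 2020), ('T', 2020), ('T', 2025)]
df = pd.DataFrame(data, columns=['ID', 'SEQ'])

N_KEYS = 2   # distinct SEQ values kept per ID
N_ROWS = 2   # rows kept per (ID, SEQ) key

keys_seen = {}   # ID -> distinct SEQ values, in order of first appearance
row_counts = {}  # (ID, SEQ) -> number of rows kept so far
kept = []        # index labels of the rows to keep

for idx, id_, seq in df.itertuples(name=None):
    keys = keys_seen.setdefault(id_, [])
    if seq not in keys:
        keys.append(seq)
    if keys.index(seq) >= N_KEYS:
        continue  # third (or later) distinct key of this ID: skip it
    n = row_counts.get((id_, seq), 0)
    if n < N_ROWS:
        kept.append(idx)
        row_counts[(id_, seq)] = n + 1

expected = df.loc[kept]
print(expected)
```

With the sample data this keeps every row except the "A Z" row and the two "B 3" rows.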


Solution

  • Your approach of using groupby and then head(2) is on the right track for getting the first 2 rows of each distinct SEQ within each ID group.

    However, there is an additional requirement: keep only the first 2 distinct SEQ values within each ID. To achieve this, you can create a new column holding the dense rank of SEQ within each ID group, use that rank to filter the rows, and finally apply your original groupby/head(2) to keep at most 2 rows per SEQ within each ID. Here's a solution using method chaining:

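    result = (df
              # dense rank of each SEQ value within its ID group (1, 2, 3, ...)
              .assign(seq_rank=lambda d: d.groupby('ID')['SEQ']
                                          .transform(lambda x: x.rank(method='dense')))
              .query('seq_rank <= 2')     # keep the two lowest-ranked SEQ values per ID
              .groupby(['ID', 'SEQ'])
              .head(2)                    # at most two rows per (ID, SEQ) key
              .drop(columns=['seq_rank'])
             )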
    
    print(result)
    

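    This gives the desired output for the sample data. Note that rank(method='dense') orders SEQ values by value, not by position, so it matches the "order of the dataset" requirement only when, within each ID, the distinct SEQ values first appear in ascending order, as they do here.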
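If the keys must follow the order of appearance even when SEQ values are not sorted, one option is to swap the dense rank for pd.factorize, which numbers distinct values in the order they first occur. This is a sketch of that substitution, not the answer's original method; first_n_keys and key_pos are hypothetical names, and the chain otherwise keeps the same groupby/head(2) step.

```python
import pandas as pd


def first_n_keys(frame, n_keys=2, n_rows=2):
    """Hypothetical helper: for each ID, keep the first `n_keys` distinct SEQ
    values in order of appearance, and at most `n_rows` rows of each."""
    return (frame
            # 0-based position of each SEQ among its ID's distinct values,
            # numbered by first appearance (pd.factorize does not sort by default)
            .assign(key_pos=lambda d: d.groupby('ID')['SEQ']
                                        .transform(lambda s: pd.factorize(s)[0]))
            .loc[lambda d: d['key_pos'] < n_keys]   # first n_keys keys per ID
            .groupby(['ID', 'SEQ'])
            .head(n_rows)                            # at most n_rows rows per key
            .drop(columns=['key_pos']))


# Sample data from the question
data = [('A', 'X'), ('A', 'X'), ('A', 'Y'), ('A', 'Z'),
        ('B', 1), ('B', 1), ('B', 2), ('B', 2), ('B', 3), ('B', 3),
        ('C', 'L-7'), ('C', 'L-9'), ('C', 'L-9'),
        ('T', 2020), ('T', 2020), ('T', 2025)]
df = pd.DataFrame(data, columns=['ID', 'SEQ'])

result = first_n_keys(df)
print(result)

# Case where value order and appearance order differ: Z appears before X and Y.
unsorted = pd.DataFrame({'ID': ['A'] * 4, 'SEQ': ['Z', 'Z', 'X', 'Y']})
print(first_n_keys(unsorted)['SEQ'].tolist())  # ['Z', 'Z', 'X']
```

On the sample data this keeps the same 13 rows as the dense-rank version, since each ID's SEQ values already first appear in ascending order. The unsorted example shows where the two approaches diverge: a dense rank would keep X and Y there, whereas factorize keeps Z and X.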