How can I change the groupby column to find the first row that meets the condition of a mask if the initial groupby fails to find it?

This is my DataFrame:

import pandas as pd
df = pd.DataFrame(
    {
        'main': ['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'y', 'y', 'y', 'y', 'y', 'y', 'y'],
        'sub': ['c', 'c', 'c', 'd', 'd', 'e', 'e', 'e', 'e', 'f', 'f', 'f', 'f', 'g', 'g', 'g'],
        'num_1': [10, 9, 80, 80, 99, 101, 110, 222, 90, 1, 7, 10, 2, 10, 95, 10],
        'num_2': [99, 99, 99, 102, 102, 209, 209, 209, 209, 100, 100, 100, 100, 90, 90, 90]
    }
)

And this is my expected output. I want to add a column named result:

   main sub  num_1  num_2  result
0     x   c     10     99     101
1     x   c      9     99     101
2     x   c     80     99     101
3     x   d     80    102     110
4     x   d     99    102     110
5     x   e    101    209     222
6     x   e    110    209     222
7     x   e    222    209     222
8     x   e     90    209     222
9     y   f      1    100     NaN
10    y   f      7    100     NaN
11    y   f     10    100     NaN
12    y   f      2    100     NaN
13    y   g     10     90      95
14    y   g     95     90      95
15    y   g     10     90      95

The mask is:

mask = (df.num_1 > df.num_2)

The process starts like this:

a) The groupby column is sub.

b) Find the first row in each group that meets the condition of the mask.

c) Put the num_1 value of that row in result, for every row of the group.

If no row in a group meets the condition of the mask, the groupby column is changed to main to find the first row that meets it. There is one condition for this phase:

The subs that come before the current sub should not be considered when using main as the groupby column.

An example of the above steps for group d in the sub column:

a) sub is the groupby column.

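b) There are no rows in the d group where df.num_1 > df.num_2.

So now, for group d, its main group (x) is searched. However, group c is also in this main group. Since it comes before group d, group c should not count for this step. The first remaining row with num_1 > 102 (the num_2 of group d) is the row with num_1 = 110 in group e, so result is 110 for group d.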
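The rule above (search the sub itself first, then the same main but only from the current sub onwards, always comparing against the current sub's num_2) can be sketched as a plain loop. This is a minimal illustrative sketch, slower than a vectorised approach on large frames: first_match_with_fallback is a hypothetical helper name, and it assumes each sub occupies one contiguous block of rows and that num_2 is constant within a sub, as in the sample data.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'main': ['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'y', 'y', 'y', 'y', 'y', 'y', 'y'],
    'sub': ['c', 'c', 'c', 'd', 'd', 'e', 'e', 'e', 'e', 'f', 'f', 'f', 'f', 'g', 'g', 'g'],
    'num_1': [10, 9, 80, 80, 99, 101, 110, 222, 90, 1, 7, 10, 2, 10, 95, 10],
    'num_2': [99, 99, 99, 102, 102, 209, 209, 209, 209, 100, 100, 100, 100, 90, 90, 90],
})


def first_match_with_fallback(df):
    # Hypothetical helper. Assumes each sub is one contiguous block of rows
    # and num_2 is constant within a sub, as in the example data.
    result = pd.Series(np.nan, index=df.index)
    for _, g in df.groupby('main', sort=False):
        num_1 = g['num_1'].to_numpy()
        # label the subs 0, 1, 2, ... in order of appearance inside this main
        codes = pd.factorize(g['sub'])[0]
        for code in range(codes.max() + 1):
            rows = np.flatnonzero(codes == code)
            threshold = g['num_2'].iloc[rows[0]]
            # search this sub and every later sub, never an earlier one;
            # the sub's own rows come first, so a hit there wins (step a)
            hits = np.flatnonzero((num_1 > threshold) & (codes >= code))
            if hits.size:
                result.loc[g.index[rows]] = num_1[hits[0]]
    return result


df['result'] = first_match_with_fallback(df)
print(df)
```

Because the current sub's own rows come before any later sub, the first hit in that window is the step a) result whenever one exists, so both phases collapse into a single search.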

In this image I have shown where those values come from:

[Image: where each result value comes from]

And this is my attempt. It works for some groups but not for all of them:

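def step_a(g):
    mask = (g.num_1 > g.num_2)
    # mark only the first row of the group where the mask is True
    g.loc[mask.cumsum().eq(1) & mask, 'result'] = g.num_1
    # spread that value to every row of the group
    g['result'] = g.result.ffill().bfill()
    return g

a = df.groupby('sub').apply(step_a)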
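Step a) on its own (the same idea as step_a) can also be written without apply. A minimal sketch, with the df repeated so the block runs on its own; first_in_sub is a hypothetical column name. Series.where blanks num_1 wherever the mask is False, and transform('first') then broadcasts the first non-NaN value of each sub group to all of its rows.

```python
import pandas as pd

df = pd.DataFrame({
    'main': ['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'y', 'y', 'y', 'y', 'y', 'y', 'y'],
    'sub': ['c', 'c', 'c', 'd', 'd', 'e', 'e', 'e', 'e', 'f', 'f', 'f', 'f', 'g', 'g', 'g'],
    'num_1': [10, 9, 80, 80, 99, 101, 110, 222, 90, 1, 7, 10, 2, 10, 95, 10],
    'num_2': [99, 99, 99, 102, 102, 209, 209, 209, 209, 100, 100, 100, 100, 90, 90, 90],
})

mask = df.num_1 > df.num_2

# keep num_1 only where the mask holds, then give every row of a sub group
# the first non-NaN value of that group (steps a to c, no fallback)
df['first_in_sub'] = df['num_1'].where(mask).groupby(df['sub']).transform('first')
print(df)
```

Groups c, d and f come out as NaN here; c and d are the ones that the fallback to main is supposed to fill.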


  • IIUC, you can use broadcasting to form a 2D mask per "main" group and use it to find the first num_1 > num_2 while only considering the current and following sub groups:

    import numpy as np
    import pandas as pd

    def find(g):
        # get sub as 0,1,2…
        sub = pd.factorize(g['sub'])[0]
        # convert inputs to numpy
        n1 = g['num_1'].to_numpy()
        n2 = g.loc[~g['sub'].duplicated(), 'num_2'].to_numpy()
        # form mask
        # (n1[:, None] > n2) -> num_1 > num_2
        # (sub[:, None] >= np.arange(len(n2))) -> exclude previous groups
        m = (n1[:, None] > n2) & (sub[:, None] >= np.arange(len(n2)))
        # first True per column (NaN if none), broadcast back to each row of its sub
        return pd.Series(np.where(m.any(0), n1[m.argmax(0)], np.nan)[sub],
                         index=g.index)

    df['result'] = df.groupby('main', group_keys=False).apply(find)

    Note that you can easily tweak the masks to implement other logic (search only in the next n groups, exclude all previous groups except the immediately preceding one, etc.).


    # example 1                               # example 2 (from comments)
       main sub  num_1  num_2  result            main sub  num_1  num_2  result
    0     x   c     10     99   101.0         0     x   d     10    102   110.0
    1     x   c      9     99   101.0         1     x   d      9    102   110.0
    2     x   c     80     99   101.0         2     x   c     80     99   101.0
    3     x   d     80    102   110.0         3     x   c     80     99   101.0
    4     x   d     99    102   110.0         4     x   c     99     99   101.0
    5     x   e    101    209   222.0         5     x   e    101    209   222.0
    6     x   e    110    209   222.0         6     x   e    110    209   222.0
    7     x   e    222    209   222.0         7     x   e    222    209   222.0
    8     x   e     90    209   222.0         8     x   e     90    209   222.0
    9     y   f      1    100     NaN         9     y   f      1    100     NaN
    10    y   f      7    100     NaN         10    y   f      7    100     NaN
    11    y   f     10    100     NaN         11    y   f     10    100     NaN
    12    y   f      2    100     NaN         12    y   f      2    100     NaN
    13    y   g     10     90    95.0         13    y   g     10     90    95.0
    14    y   g     95     90    95.0         14    y   g     95     90    95.0
    15    y   g     10     90    95.0         15    y   g     10     90    95.0

    Intermediate masks m, here for the second example:

    # main == 'x'
    #         102     99    209
    array([[False, False, False],  #  10
           [False, False, False],  #   9
           [False, False, False],  #  80
           [False, False, False],  #  80
           [False, False, False],  #  99
           [False,  True, False],  # 101
           [ True,  True, False],  # 110
           [ True,  True,  True],  # 222
           [False, False, False]]) #  90
    # out:    110    101    222
    # main == 'y'
    #         100     90
    array([[False, False],  #   1
           [False, False],  #   7
           [False, False],  #  10
           [False, False],  #   2
           [False, False],  #  10
           [False,  True],  #  95
           [False, False]]) #  10
    # out:    NaN     95
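The last line of find reduces each column of m to one value. argmax(0) gives the row of the first True in each column, but it also gives 0 for a column with no True at all, which is why m.any(0) is used to replace those columns with NaN. A quick check with the main == 'y' mask above (num_1 values taken from the data):

```python
import numpy as np

n1 = np.array([1, 7, 10, 2, 10, 95, 10])  # num_1 for main == 'y'
m = np.array([[False, False],
              [False, False],
              [False, False],
              [False, False],
              [False, False],
              [False, True],
              [False, False]])

first = m.argmax(0)    # row of the first True per column (0 if there is none)
has_match = m.any(0)   # False for columns without any True
out = np.where(has_match, n1[first], np.nan)
print(first, has_match, out)
```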