Rolling window selection with groupby in pandas


I have the following pandas dataframe:

import pandas as pd

# Create the DataFrame
df = pd.DataFrame({
    'id': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
    'date': [1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12],
    'value': [11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28]
})
df

    id  date  value
0    1     1     11
1    1     2     12
2    1     3     13
3    1     4     14
4    1     5     15
5    1     6     16
6    1     7     17
7    1     8     18
8    2     5     21
9    2     6     22
10   2     7     23
11   2     8     24
12   2     9     25
13   2    10     26
14   2    11     27
15   2    12     28

I want to query the above DataFrame in a rolling-window manner, for both ids at once. The window should cover n consecutive rows (ordered by date) per id.

So, if n==2, in the 1st iteration I would like to query this:

df.query('(id==1 and (date==1 or date==2)) or (id==2 and (date==5 or date==6))')

   id  date  value
0   1     1     11
1   1     2     12
8   2     5     21
9   2     6     22

In the 2nd iteration I would like to query this:

df.query('(id==1 and (date==2 or date==3)) or (id==2 and (date==6 or date==7))')

    id  date  value
1    1     2     12
2    1     3     13
9    2     6     22
10   2     7     23

In the 3rd iteration I would like to query this:

df.query('(id==1 and (date==3 or date==4)) or (id==2 and (date==7 or date==8))')

    id  date  value
2    1     3     13
3    1     4     14
10   2     7     23
11   2     8     24

And so on. How can I do this in pandas? My data has around 500 ids.
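Because the dates differ between ids, the windows are easiest to describe by position: in iteration i (counting from 0), each id keeps its rows at positions i to i + n - 1 once its rows are ordered by date. A minimal sketch of that reading of the question (one possible approach, not the only one) numbers the rows within each id with groupby.cumcount and selects each window with a boolean mask; `pos`, `max_size` and `windows` are illustrative names:

```python
import pandas as pd

df = pd.DataFrame({
    'id': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
    'date': [1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12],
    'value': [11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28],
})

n = 2  # window size

# Order rows by date inside each id, then number them 0, 1, 2, ... per id
df = df.sort_values(['id', 'date'])
pos = df.groupby('id').cumcount()

# The largest group determines how many windows there are
max_size = pos.max() + 1
n_windows = max_size - n + 1

windows = []
for i in range(n_windows):
    # Keep the rows whose within-id position falls in [i, i + n - 1]
    windows.append(df[pos.between(i, i + n - 1)])

print(windows[0])
```

With this sketch, ids that have fewer rows than the largest group simply contribute fewer (or no) rows to the later windows. Each window is one boolean filter over the whole frame, so the per-window cost does not depend on the number of ids.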


Solution

  • The exact expected logic is not fully clear, but assuming you want to loop over the rolling windows, you could combine groupby.nth with NumPy's sliding_window_view. Because the DataFrameGroupBy object is reused, the groups only need to be computed once:

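    import numpy as np
    from numpy.lib.stride_tricks import sliding_window_view as swv
    
    n = 2  # window size
    
    # length of the largest group, i.e. the number of positions to slide over
    max_size = df['id'].value_counts(sort=False).max()
    # sort by date within each id and build the groupby object once
    g = df.sort_values(by=['id', 'date']).groupby('id', sort=False)
    
    # each window is an array of n consecutive positions: [0 1], [1 2], ...
    for idx in swv(np.arange(max_size), n):
        print(f'rows {idx}')
        # nth picks those positions in every group at once
        # (with pandas >= 2.0 it acts as a filter and keeps the original index)
        print(g.nth(idx))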
    

    Output:

    rows [0 1]
       id  date  value
    0   1     1     11
    1   1     2     12
    8   2     5     21
    9   2     6     22
    rows [1 2]
        id  date  value
    1    1     2     12
    2    1     3     13
    9    2     6     22
    10   2     7     23
    rows [2 3]
        id  date  value
    2    1     3     13
    3    1     4     14
    10   2     7     23
    11   2     8     24
    rows [3 4]
        id  date  value
    3    1     4     14
    4    1     5     15
    11   2     8     24
    12   2     9     25
    rows [4 5]
        id  date  value
    4    1     5     15
    5    1     6     16
    12   2     9     25
    13   2    10     26
    rows [5 6]
        id  date  value
    5    1     6     16
    6    1     7     17
    13   2    10     26
    14   2    11     27
    rows [6 7]
        id  date  value
    6    1     7     17
    7    1     8     18
    14   2    11     27
    15   2    12     28
    

    Alternatively, assuming all groups have the same size and the DataFrame is sorted by id/date, you can build the row positions directly and slide over them (shifted indexing):

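    n = 2
    ngroups = df['id'].nunique()
    # column-major reshape: column k holds the row positions of the k-th id
    positions = np.arange(len(df)).reshape(-1, ngroups, order='F')
    # slide along the rows; each idx has shape (ngroups, n)
    for idx in swv(positions, n, axis=0):
        print(f'indices: {idx.ravel()}')
        print(df.iloc[idx.ravel()])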
    

    Output:

    indices: [0 1 8 9]
       id  date  value
    0   1     1     11
    1   1     2     12
    8   2     5     21
    9   2     6     22
    indices: [ 1  2  9 10]
        id  date  value
    1    1     2     12
    2    1     3     13
    9    2     6     22
    10   2     7     23
    indices: [ 2  3 10 11]
        id  date  value
    2    1     3     13
    3    1     4     14
    10   2     7     23
    11   2     8     24
    indices: [ 3  4 11 12]
        id  date  value
    3    1     4     14
    4    1     5     15
    11   2     8     24
    12   2     9     25
    indices: [ 4  5 12 13]
        id  date  value
    4    1     5     15
    5    1     6     16
    12   2     9     25
    13   2    10     26
    indices: [ 5  6 13 14]
        id  date  value
    5    1     6     16
    6    1     7     17
    13   2    10     26
    14   2    11     27
    indices: [ 6  7 14 15]
        id  date  value
    6    1     7     17
    7    1     8     18
    14   2    11     27
    15   2    12     28
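The column-major reshape in the second approach is the least obvious step. A small sketch on the question's data shows the intermediate position matrix and the windows it yields, then checks the first three windows against the df.query calls from the question. Like the answer, it assumes equal-size groups sorted by id/date; `pos` and `windows` are names chosen for this illustration:

```python
import numpy as np
import pandas as pd
from numpy.lib.stride_tricks import sliding_window_view as swv

# Same data as in the question (two ids, 8 rows each, sorted by id/date)
df = pd.DataFrame({
    'id': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
    'date': [1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12],
    'value': [11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28],
})

n = 2
ngroups = df['id'].nunique()

# order='F' fills column by column, so column k lists the rows of the k-th id
pos = np.arange(len(df)).reshape(-1, ngroups, order='F')
print(pos[:3])
# [[ 0  8]
#  [ 1  9]
#  [ 2 10]]

# Sliding along axis 0 yields, per step, an (ngroups, n) array:
# one row of n consecutive positions for each id
windows = swv(pos, n, axis=0)
print(windows.shape)  # (7, 2, 2)
print(windows[0])
# [[0 1]
#  [8 9]]

# The first three windows match the queries written out in the question
queries = [
    '(id==1 and (date==1 or date==2)) or (id==2 and (date==5 or date==6))',
    '(id==1 and (date==2 or date==3)) or (id==2 and (date==6 or date==7))',
    '(id==1 and (date==3 or date==4)) or (id==2 and (date==7 or date==8))',
]
for step, q in enumerate(queries):
    assert df.iloc[windows[step].ravel()].equals(df.query(q))
```

Flattening each window with ravel() keeps the ids in order and the positions ascending within each id, which is why the selected rows come out in the same order as the original DataFrame.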