Tags: python, pandas, dataframe, algorithm, rolling-computation

Identify n consecutive values in a column using a rolling window


I have a DataFrame in which I compare two columns, flagging the rows where column a is larger than column b:

df['y'] = (df['a'] > df['b']).astype(int) 

This gives a new column of 1s and 0s. Now I would like to use a rolling window to find all rows where the value 1 occurs at least 9 times in a row, copy the corresponding values of df['a'] into a new column z, and fill the remaining rows with NaN. Expected result:

   a         b         y    z
-0.143    -0.109       0    nan
-0.109    -0.108       1    nan
-0.118    -0.108       1    nan
-0.118    -0.113       1    nan
-0.090    -0.110       0    nan
-0.090    -0.108       0    nan
-0.075    -0.050       1    -0.075
-0.075    -0.059       1    -0.075 
-0.065    -0.056       1    -0.065
-0.065    -0.053       1    -0.065
-0.042    -0.040       1    -0.042
-0.042    -0.039       1    -0.042
-0.064    -0.060       1    -0.064
-0.064    -0.057       1    -0.064
-0.055    -0.054       1    -0.055
-0.055    -0.053       1    -0.055
-0.068    -0.069       0    nan
-0.068    -0.056       1    nan
-0.074    -0.075       0    nan
-0.063    -0.076       0    nan
-0.074    -0.056       1    nan
-0.063    -0.069       0    nan
-0.077    -0.075       1    nan
-0.077    -0.050       1    nan
-0.082    -0.058       1    nan
-0.127    -0.056       1    nan
-0.095    -0.100       0    nan
-0.095    -0.094       1    nan
-0.108    -0.096       1    nan
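To make the example reproducible, here is a sketch that rebuilds the sample frame from the table above. The column name `z_expected` is hypothetical and only holds the desired result. Note that `y` is copied verbatim from the table: recomputing `(df['a'] > df['b']).astype(int)` from the displayed, rounded values does not reproduce it in every row (row 1, for example), so the sketch takes `y` as given.

```python
import pandas as pd

nan = float("nan")

# Values copied from the question's table
df = pd.DataFrame({
    "a": [-0.143, -0.109, -0.118, -0.118, -0.090, -0.090, -0.075, -0.075,
          -0.065, -0.065, -0.042, -0.042, -0.064, -0.064, -0.055, -0.055,
          -0.068, -0.068, -0.074, -0.063, -0.074, -0.063, -0.077, -0.077,
          -0.082, -0.127, -0.095, -0.095, -0.108],
    "b": [-0.109, -0.108, -0.108, -0.113, -0.110, -0.108, -0.050, -0.059,
          -0.056, -0.053, -0.040, -0.039, -0.060, -0.057, -0.054, -0.053,
          -0.069, -0.056, -0.075, -0.076, -0.056, -0.069, -0.075, -0.050,
          -0.058, -0.056, -0.100, -0.094, -0.096],
    "y": [0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1],
})

# Hypothetical expected column: a copied on the only run of >= 9 ones (rows 6..15)
df["z_expected"] = [nan] * 6 + df["a"].iloc[6:16].tolist() + [nan] * 13

# Rows where recomputing the flag from the rounded a/b values disagrees with y
mismatch = df.index[(df["a"] > df["b"]).astype(int) != df["y"]].tolist()
print(mismatch)
```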

Solution

  • IIUC, count the length of each run of consecutive 1s and build a mask from it:

    N = 9
    # identify non-1s
    m1 = df['y'].ne(1)
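    # count the 1s in each run and check if ≥ N
    # (summing y instead of taking the group size avoids counting the non-1 row
    #  that opens each group, which is missing for a run at the very start)
    m2 = df.groupby(m1.cumsum())['y'].transform('sum').ge(N)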
    # copy a if the above matches
    df['z'] = df['a'].where(~m1&m2)
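One caveat with counting group sizes: each group normally begins with the non-1 row that closes the previous run (hence `gt(N)`), but a run at the very top of the frame has no such row, so it is undercounted by one. A minimal check on a hypothetical 10-row frame whose first 9 rows are 1s, comparing the size-based mask with one that sums `y` instead (which counts only the 1s, since `y` is 0/1):

```python
import pandas as pd

N = 9
# Hypothetical frame: the first 9 rows form a run of exactly N ones
df = pd.DataFrame({"a": [float(i) for i in range(10)], "y": [1] * 9 + [0]})

m1 = df["y"].ne(1)      # True on rows that break a run
groups = m1.cumsum()    # rows 0..8 -> group 0 (no opening 0), row 9 -> group 1

# Group size counts every row, so group 0 has size 9 and fails gt(N)
by_size = df.groupby(groups)["y"].transform("size").gt(N)
# Summing y counts only the 1s, whether or not a 0 opens the group
by_sum = df.groupby(groups)["y"].transform("sum").ge(N)

z_size = df["a"].where(~m1 & by_size)
z_sum = df["a"].where(~m1 & by_sum)
print(z_size.notna().sum(), z_sum.notna().sum())  # 0 9
```

For runs in the middle of the frame both masks agree; they only differ when the data starts with a run of 1s.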
    

    For fun, an alternative using rolling:

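    N = 9
    # True where the N rows ending here are all 1s
    m1 = df['y'].rolling(N, min_periods=1).sum().eq(N)
    # spread each hit back over the N rows of its window (reverse, then rolling max);
    # .where() aligns on the index, so the reversed order does not matter
    m2 = m1[::-1].astype(int).rolling(N, min_periods=1).max().eq(1)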
    
    df['z'] = df['a'].where(m2)
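A self-contained sketch of the rolling variant, wrapped in a hypothetical helper `rolling_long_runs` (the name and its parameters are illustrative, not a pandas API). It reverses the intermediate result back before masking rather than relying on index alignment, and runs on `y` and `a` copied from the question's table:

```python
import pandas as pd

def rolling_long_runs(df, n=9, flag="y", src="a"):
    """Copy df[src] on rows inside a run of >= n consecutive 1s in df[flag], else NaN."""
    # full[i] is True when rows i-n+1 .. i are all 1, i.e. an all-ones window ends at i
    full = df[flag].rolling(n, min_periods=1).sum().eq(n)
    # Reverse, take a rolling max over n rows, then reverse back:
    # covered[i] is True if an all-ones window ends anywhere in rows i .. i+n-1
    covered = full[::-1].astype(int).rolling(n, min_periods=1).max()[::-1].eq(1)
    return df[src].where(covered)

# Values copied from the question's table
df = pd.DataFrame({
    "a": [-0.143, -0.109, -0.118, -0.118, -0.090, -0.090, -0.075, -0.075,
          -0.065, -0.065, -0.042, -0.042, -0.064, -0.064, -0.055, -0.055,
          -0.068, -0.068, -0.074, -0.063, -0.074, -0.063, -0.077, -0.077,
          -0.082, -0.127, -0.095, -0.095, -0.108],
    "y": [0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1],
})
df["z"] = rolling_long_runs(df)
print(df.index[df["z"].notna()].tolist())  # [6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
```

The rolling version needs two passes over the data and only ever looks N rows ahead or behind, whereas the groupby version measures each run directly; for a boolean "at least N in a row" test the two give the same mask.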
    

    output:

            a      b  y      z
    0  -0.143 -0.109  0    NaN
    1  -0.109 -0.108  1    NaN
    2  -0.118 -0.108  1    NaN
    3  -0.118 -0.113  1    NaN
    4  -0.090 -0.110  0    NaN
    5  -0.090 -0.108  0    NaN
    6  -0.075 -0.050  1 -0.075
    7  -0.075 -0.059  1 -0.075
    8  -0.065 -0.056  1 -0.065
    9  -0.065 -0.053  1 -0.065
    10 -0.042 -0.040  1 -0.042
    11 -0.042 -0.039  1 -0.042
    12 -0.064 -0.060  1 -0.064
    13 -0.064 -0.057  1 -0.064
    14 -0.055 -0.054  1 -0.055
    15 -0.055 -0.053  1 -0.055
    16 -0.068 -0.069  0    NaN
    17 -0.068 -0.056  1    NaN
    18 -0.074 -0.075  0    NaN
    19 -0.063 -0.076  0    NaN
    20 -0.074 -0.056  1    NaN
    21 -0.063 -0.069  0    NaN
    22 -0.077 -0.075  1    NaN
    23 -0.077 -0.050  1    NaN
    24 -0.082 -0.058  1    NaN
    25 -0.127 -0.056  1    NaN
    26 -0.095 -0.100  0    NaN
    27 -0.095 -0.094  1    NaN
    28 -0.108 -0.096  1    NaN
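To see why only rows 6 to 15 are filled, it can help to list every run of equal `y` values with its start row and length. A sketch, assuming the common `shift`/`cumsum` run labelling, which starts a new group at every change of value (slightly different from the `ne(1).cumsum()` grouping used above):

```python
import pandas as pd

# y column copied from the question's table
y = pd.Series([0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
               0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1])

# A new run starts wherever the value differs from the previous row
run_id = y.ne(y.shift()).cumsum()

# One row per run: its value, first index label, and number of rows
grouped = y.groupby(run_id)
runs = pd.DataFrame({
    "value": grouped.first(),
    "start": y.index.to_series().groupby(run_id).first(),
    "length": grouped.size(),
})

# Runs of 1s long enough to be copied into z (N = 9)
long_runs = runs[(runs["value"] == 1) & (runs["length"] >= 9)]
print(long_runs)
```

Only the run starting at row 6 reaches 9 ones; the next longest run of 1s has length 4.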