Tags: python, pandas, numpy, vectorization

Count Instances Until Value Repeats


I have a list/array/Series of booleans, and I want a list of the same length that contains NaN wherever the original has a False value and, at each True index, the number of False elements before the next True. If there is no subsequent True, the last "count" should be np.inf.
For example:

input: [False, False, True, True, False, False, True, False, True, False, False, False]
output: [NaN, NaN, 0, 2, NaN, NaN, 1, NaN, np.inf, NaN, NaN, NaN]

Of course, I could iterate over the array/list and count (better to start the iteration from the end in that case). I'm interested in a vectorized implementation.
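For reference, the loop mentioned above might look like the sketch below (the function name `count_until_next_true_loop` is hypothetical). It walks from the end, counting Falses since the most recently seen True, and is mainly useful as a baseline to check a vectorized version against.

```python
import math

def count_until_next_true_loop(values):
    """Hypothetical reference loop: scan from the end, counting Falses
    seen since the most recent True to the right."""
    out = [math.nan] * len(values)
    falses = 0          # Falses seen since the last True (to the right)
    seen_true = False   # whether any True has been seen to the right yet
    for i in range(len(values) - 1, -1, -1):
        if values[i]:
            # the last True in the sequence has no later True -> inf
            out[i] = float(falses) if seen_true else math.inf
            seen_true = True
            falses = 0
        else:
            falses += 1
    return out

print(count_until_next_true_loop(
    [False, False, True, True, False, False, True, False, True, False, False, False]))
```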


Solution

  • Assuming an array:

    import numpy as np
    
    a = np.array([False, False, True, True, False, False, True, False, True, False, False, False])
    
    # positions of the True values
    idx = np.nonzero(a)[0]
    # array([2, 3, 6, 8])
    
    # output: NaN everywhere by default
    out = np.full_like(a, np.nan, dtype=float)
    
    # all but the last True: number of Falses before the next True
    out[idx[:-1]] = np.diff(idx) - 1
    # the last True (if any) has no later True, so it gets np.inf
    out[idx[-1:]] = np.inf
    

    Output:

    array([nan, nan,  0.,  2., nan, nan,  1., nan, inf, nan, nan, nan])
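The line `np.diff(idx) - 1` is a direct translation of a simple count: if $i_1 < i_2 < \dots < i_m$ are the positions of the True values, the Falses strictly between two consecutive Trues number one less than the gap between their positions. For the example, idx = (2, 3, 6, 8) gives diffs (1, 3, 2) and therefore counts (0, 2, 1), with the final True set to infinity.

```latex
c_k =
\begin{cases}
  i_{k+1} - i_k - 1, & 1 \le k < m \\
  \infty,            & k = m
\end{cases}
\qquad\text{e.g. } (3-2-1,\; 6-3-1,\; 8-6-1) = (0,\; 2,\; 1)
```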
    

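    Assuming a Series with a RangeIndex (so that labels equal positions):

    import pandas as pd
    
    s = pd.Series([False, False, True, True, False, False, True, False, True, False, False, False])
    
    t = s[s].index  # labels of the True values
    out = pd.Series(np.nan, index=s.index)
    out.loc[t] = (-t.to_series().diff(-1) - 1).fillna(np.inf).to_numpy()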
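The label-based version only works because a RangeIndex makes label differences equal to position differences. If the index is arbitrary (strings, dates, a filtered index), one option is to compute on positions and reattach the index at the end. This is a sketch; the helper name `count_until_next_true` is hypothetical.

```python
import numpy as np
import pandas as pd

def count_until_next_true(s: pd.Series) -> pd.Series:
    """Hypothetical helper: computed on positions rather than labels,
    so it works for any index, not only a RangeIndex."""
    pos = np.flatnonzero(s.to_numpy(dtype=bool))  # positions of the True values
    values = np.full(len(s), np.nan)              # NaN wherever there is no True
    values[pos[:-1]] = np.diff(pos) - 1           # Falses between consecutive Trues
    values[pos[-1:]] = np.inf                     # last True: no later True
    return pd.Series(values, index=s.index)

s = pd.Series([False, False, True, True, False, False, True, False, True, False, False, False],
              index=list("abcdefghijkl"))
print(count_until_next_true(s))
```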
    

    Or, a more involved approach with groupby:

    s = pd.Series([False, False, True, True, False, False, True, False, True, False, False, False])
    
    # reverse the Series so that counting starts from the end
    s2 = s[::-1]
    # True for every position before the last True (in original order)
    m2 = s2.shift(fill_value=False).cummax()
    
    # number of Falses before the next True, kept only at True positions
    out = s2.groupby(s2.cumsum()).cumcount().shift()[::-1].where(s2)
    # the last True has no later True: set it to np.inf
    out[s2 & ~m2] = np.inf
    

    Output:

    0     NaN
    1     NaN
    2     0.0
    3     2.0
    4     NaN
    5     NaN
    6     1.0
    7     NaN
    8     inf
    9     NaN
    10    NaN
    11    NaN
    dtype: float64
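To see why the groupby version works, it can help to put its intermediate Series side by side. The sketch below recomputes the same steps (the column names are only illustrative) and shows them in original order. Note that the last True picks up a leftover count from the trailing Falses (2 here, while there are 3 trailing Falses), which is why the answer overwrites it with np.inf.

```python
import pandas as pd

s = pd.Series([False, False, True, True, False, False, True, False, True, False, False, False])

s2 = s[::-1]                           # walk from the end
group = s2.cumsum()                    # group id, incremented at each True (reversed order)
count = s2.groupby(group).cumcount()   # 0, 1, 2, ... within each group (reversed order);
                                       # a group starting at a True ends with the number of
                                       # Falses between that True and the previous True
shifted = count.shift()                # moves each group's last value onto the next row in
                                       # reversed order, i.e. the previous True in original order
before_last_true = s2.shift(fill_value=False).cummax()  # True before the last True (original order)

steps = pd.DataFrame({
    "s": s2,
    "group": group,
    "cumcount": count,
    "shifted": shifted,
    "before_last_true": before_last_true,
})[::-1]                               # back to original order for reading
print(steps)
```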