python, pandas, dataframe, vectorization

In a pandas DataFrame, find whether a True value in column A is its first occurrence since the last True in column B


I'm looking for the most efficient way to determine whether a True value in column A is the first occurrence since the last True value in column B.

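In these examples, the expected output is column C. A True in column A that comes before any True in column B does not count as a first occurrence (see row 0 of Example 2).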

Example 1:

import pandas as pd

df = pd.DataFrame({
    'A': [False, False, True, False, True, False, True, False, True],
    'B': [True, False, False, False, False, True, False, False, False],
    'C': [False, False, True, False, False, False, True, False, False]
})
       A      B      C
0  False   True  False
1  False  False  False
2   True  False   True
3  False  False  False
4   True  False  False
5  False   True  False
6   True  False   True
7  False  False  False
8   True  False  False

Example 2:

df = pd.DataFrame({
    'A': [True, False, False, True, False, True, False, True, False],
    'B': [False, True, False, False, False, False, True, False, False],
    'C': [False, False, False, True, False, False, False, True, False]
})
       A      B      C
0   True  False  False
1  False   True  False
2  False  False  False
3   True  False   True
4  False  False  False
5   True  False  False
6  False   True  False
7   True  False   True
8  False  False  False

Example 3:

Here you can find a .csv file with a bigger example
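To pin down the rule that the examples encode, here is a minimal plain-Python reference sketch (not part of the original question; the function name `first_a_since_b_loop` is hypothetical). It walks the rows once, "arming" on every True in B and firing on the next True in A. It assumes that a row where both A and B are True counts as a first occurrence, because B starts the new window on that same row; neither example contains such a row, so that case is an assumption (it is also how grouping by `B.cumsum()` behaves, since the new group starts on that row). It is O(n) but not vectorized, so it is mainly useful for checking a faster solution.

```python
import pandas as pd


def first_a_since_b_loop(a, b):
    """Hypothetical reference: True where A is the first True since the last True in B."""
    result = []
    armed = False  # True once B has fired and no A has been marked since
    for a_val, b_val in zip(a, b):
        if b_val:
            armed = True  # assumption: a True in B opens a window that includes this row
        if a_val and armed:
            result.append(True)
            armed = False  # later Trues in A inside the same window are not marked
        else:
            result.append(False)
    return pd.Series(result, index=a.index)


# Example 2 from the question
df = pd.DataFrame({
    'A': [True, False, False, True, False, True, False, True, False],
    'B': [False, True, False, False, False, False, True, False, False],
})
df['C'] = first_a_since_b_loop(df['A'], df['B'])
print(df['C'].tolist())
# prints [False, False, False, True, False, False, False, True, False]
```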


Solution

  • You can use a groupby operation on the cumulative sum of column "B" to group your DataFrame as you described: every True in "B" increments the cumulative sum and therefore starts a new group. Then you can use idxmax to get the index of the first True in column "A" within each of those groups. Once you have those indices, you can create your new column "C".

    Using idxmax is a bit of a trick: we're not actually interested in the maximum value, since column "A" only ever contains True and False. idxmax returns the index of the first occurrence of the maximum (here, the first True within each group), which is exactly what we need. If a group contains no True at all, idxmax returns the group's first index instead, which is why the code also checks each group's max value.

    import pandas as pd
    
    df = pd.DataFrame({
        'A': [False, False, True, False, True, False, True, False, True],
        'B': [True, False, False, False, False, True, False, False, False],
    })
    
    # for each group (a new group starts at every True in "B"), get the index
    # label of the first maximum of "A" as well as the maximum itself
    indices_df = df["A"].groupby(df["B"].cumsum()).agg(["idxmax", "max"])
    
    # mask to filter out group 0 (rows before the first True in "B")
    skip_0th = (indices_df.index > 0)
    
    # mask to filter out groups that contain no True in "A"
    groups_with_true = indices_df["max"].astype(bool)
    
    # combine masks and retrieve the appropriate index
    indices = indices_df.loc[skip_0th & groups_with_true, "idxmax"]
    
    df["C"] = False
    df.loc[indices, "C"] = True
    
    print(df)
           A      B      C
    0  False   True  False
    1  False  False  False
    2   True  False   True
    3  False  False  False
    4   True  False  False
    5  False   True  False
    6   True  False   True
    7  False  False  False
    8   True  False  False
    

    Updated for example 2.

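    In example 2, the True in column "A" at row 0 comes before any True in column "B", so it must not be marked. We can handle this by slicing our indices Series to exclude any entry whose label is 0 (i.e. label slicing from 1 to the end). This works because our groupby operation assigns integer labels according to the .cumsum(), and every row before the first True in column "B" gets label 0. In example 1, the smallest label is 1 (since the first value in column "B" is True), whereas in example 2 it is 0. Since we don't want group 0 to affect our results, we can simply slice it away from our indices. (The skip_0th mask in the code above does the same thing; the printouts below show indices filtered with groups_with_true only, before group 0 is removed.)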

    When we assign "C" after slicing our indices Series, we correctly ignore every value that comes before the first True in column "B".

    Enough text, though; let's see some code.

    Example 1

    print(indices)
    1    2
    2    6
    
    # Slicing here doesn't change anything, since `indices` has no entry
    #  with label 0
    indices = indices.loc[1:]
    print(indices)
    1    2
    2    6
    

    Example 2

    print(indices)
    0    0
    1    3
    2    7
    
    # we don't want to include the entry with label 0 in `indices`,
    #  so we use slicing to remove it
    
    indices = indices.loc[1:]
    print(indices)
    1    3
    2    7
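Putting the pieces together, here is a minimal sketch that wraps the groupby/idxmax approach from this answer in a function (the name `first_a_since_b` is hypothetical) and runs it on both examples. It uses the boolean-mask form of the group-0 filter rather than `.loc[1:]`; both drop the same entry. The last part shows why the `"max"` filter matters: for a group whose "A" values are all False, idxmax still reports that group's first row, which would otherwise be marked True.

```python
import pandas as pd


def first_a_since_b(df):
    """Hypothetical wrapper around the groupby/idxmax approach described above."""
    # Every True in "B" starts a new group; rows before the first True in "B" are group 0.
    groups = df["B"].cumsum()
    indices_df = df["A"].groupby(groups).agg(["idxmax", "max"])

    # Keep groups after the first True in "B" that actually contain a True in "A".
    keep = (indices_df.index > 0) & indices_df["max"].to_numpy(dtype=bool)
    indices = indices_df.loc[keep, "idxmax"]

    out = pd.Series(False, index=df.index)
    out.loc[indices.to_numpy()] = True  # mark the first True of each kept group
    return out


ex1 = pd.DataFrame({
    "A": [False, False, True, False, True, False, True, False, True],
    "B": [True, False, False, False, False, True, False, False, False],
})
ex2 = pd.DataFrame({
    "A": [True, False, False, True, False, True, False, True, False],
    "B": [False, True, False, False, False, False, True, False, False],
})
print(first_a_since_b(ex1).tolist())
# prints [False, False, True, False, False, False, True, False, False]
print(first_a_since_b(ex2).tolist())
# prints [False, False, False, True, False, False, False, True, False]

# Group 1 below (rows 1-2) has no True in "A", yet idxmax still returns row 1.
no_true = pd.DataFrame({"A": [True, False, False], "B": [False, True, False]})
print(no_true["A"].groupby(no_true["B"].cumsum()).idxmax().tolist())
# prints [0, 1]
print(first_a_since_b(no_true).tolist())
# prints [False, False, False]
```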