Python find and keep patterns in a list and replace others


I am trying to write code that extracts patterns from each list in a list of lists. I search for patterns of a specified length, where a pattern is a 'B-' followed by one or more 'I-'. For example, in the following list I want to keep patterns of length two and replace everything else with a specified string:

list = ['O', 'B-', 'I-', 'I-', 'O', 'B-', 'I-', 'B-']

Expected output should be as follows:

expected_list_2 = ['O', 'O', 'O', 'O', 'O', 'B-', 'I-', 'O']

As can be seen, only the length-two pattern 'B-', 'I-' is kept, and all other elements are replaced with the 'O' label.

If I want to keep patterns of length three, the output should be as follows:

expected_list_3 = ['O', 'B-', 'I-', 'I-', 'O', 'O', 'O', 'O']

Each element of my list of lists is a list like this one, and I want to apply this task to every one of them. Is there an efficient or clever way to do this, instead of writing if-else conditions and looping over each element?


Solution

  • This solution should (please test with more relevant cases before deploying to production) find all positions in a list where the pattern 'B-' followed by exactly n-1 'I-' occurs. I extended the example into list1 to cover more cases, such as a pattern at the start or end of the list and consecutive patterns.

    list1 = ['B-', 'I-', 'I-', 'O', 'B-', 'I-', 'I-', 'B-', 'I-', 'B-', 'I-', 'O', 'B-', 'I-', 'I-']
    #n = 2:                                            ^^^^^^^^^   ^^^^^^^^^
    #n = 3:  ^^^^^^^^^^^^^^^^        ^^^^^^^^^^^^^^^                                ^^^^^^^^^^^^^^^
    
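    def find_pattern(list1, n):
        # pattern to keep: one 'B-' followed by n-1 'I-'
        pattern = ['B-'] + ['I-'] * (n-1)
        first = pattern[0]
        
        # find starting indices of matching patterns;
        # list1[:-n+1 or None] skips start positions too close to the end
        # (for n == 1 the "or None" keeps the whole list)
        idx = [i for i, tag in enumerate(list1[:-n+1 or None])
                 if tag == first                 # cheap check before comparing slices
                and list1[i:i+n] == pattern
                and list1[i+n:i+n+1] != ['I-']]  # reject runs longer than n
    
        # start from all 'O' and write the pattern back at those indices
        res = ['O'] * len(list1)
        for i in idx:
            res[i:i+n] = pattern
        return res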
    
    print(find_pattern(list1, 2))
    print(find_pattern(list1, 3))
    

    Output

    ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-', 'I-', 'B-', 'I-', 'O', 'O', 'O', 'O']
    ['B-', 'I-', 'I-', 'O', 'B-', 'I-', 'I-', 'O', 'O', 'O', 'O', 'O', 'B-', 'I-', 'I-']
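
Since the question is about a list of lists, here is a minimal sketch of how the same idea could be applied to every inner list. It restates the matching logic as a plain loop so the block runs on its own. The names `find_pattern_all`, `tags` and `question_list` are made up for this illustration, and the check at the end uses the question's own example.

```python
def find_pattern(tags, n):
    """Keep 'B-' followed by exactly n-1 'I-'; replace everything else with 'O'."""
    pattern = ['B-'] + ['I-'] * (n - 1)
    res = ['O'] * len(tags)
    # only start positions that leave room for a full pattern
    for start in range(len(tags) - n + 1):
        if (tags[start:start + n] == pattern
                and tags[start + n:start + n + 1] != ['I-']):  # run must not be longer than n
            res[start:start + n] = pattern
    return res


def find_pattern_all(list_of_lists, n):
    # hypothetical helper: apply find_pattern to each inner list
    return [find_pattern(tags, n) for tags in list_of_lists]


# the example list from the question
question_list = ['O', 'B-', 'I-', 'I-', 'O', 'B-', 'I-', 'B-']

print(find_pattern_all([question_list], 2))
print(find_pattern_all([question_list], 3))
```

For this input, the n = 2 and n = 3 results should match expected_list_2 and expected_list_3 from the question. Each inner list is scanned once per call, so the cost grows with the total number of elements times n.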