
Incrementally add one to a list of indices given the positions of a specified item in a list


Given a list of tokens as input:

>>> tokenized_text = "[CLS] my dog is cute [SEP] he likes slack ##ing [SEP]".split()
>>> tokenized_text 
['[CLS]', 'my', 'dog', 'is', 'cute', '[SEP]', 'he', 'likes', 'slack', '##ing', '[SEP]']

The goal is to build a segment index from left to right: every token up to and including the first [SEP] gets 0, and the index is incremented by 1 after each [SEP]. The desired output indices for the tokenized_text list above are:

[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

I've tried:

>>> import numpy as np
>>> # Find the indices of `[SEP]`.
>>> sep_indices = np.where(np.array(tokenized_text) == "[SEP]")[0]
>>> sep_indices
array([ 5, 10])

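>>> prev = 0
>>> out = []
>>> for i, idx in enumerate(sep_indices):
...     # Emit segment id `i` once per position between consecutive `[SEP]` tokens.
...     for _ in range(idx - prev):
...         out.append(i)
...     prev = idx
... 
>>> # The loop yields one entry too few, so prepend a 0 for the first token.
>>> out = [0] + out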
>>> out
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

But is there an easier way to achieve the correct output?


Solution

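  • A simpler, vectorized way with NumPy: build a boolean mask of the [SEP] positions, take its cumulative sum, and subtract the mask so that each [SEP] stays in the segment it closes.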

    In [116]: a = np.asarray(tokenized_text)
    
    In [117]: m = a == "[SEP]"
    
    In [118]: m.cumsum()-m
    Out[118]: array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
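
For reuse, the same idea can be wrapped in a small function. The sketch below is illustrative rather than part of the original answer: the names `segment_ids_numpy` and `segment_ids_python` and the `sep` parameter are assumptions. The NumPy version assumes a non-empty list of strings. The second function is a plain-Python equivalent for cases where NumPy is not available.

```python
import numpy as np


def segment_ids_numpy(tokens, sep="[SEP]"):
    """Return one segment index per token; each separator closes its segment."""
    mask = np.asarray(tokens) == sep
    # cumsum counts the separators seen so far, including the current token.
    # Subtracting the mask keeps each separator in the segment it closes.
    return (mask.cumsum() - mask).tolist()


def segment_ids_python(tokens, sep="[SEP]"):
    """Same result without NumPy, using a running counter."""
    ids, current = [], 0
    for tok in tokens:
        ids.append(current)
        # Increment only after recording the separator's own index.
        if tok == sep:
            current += 1
    return ids


tokenized_text = "[CLS] my dog is cute [SEP] he likes slack ##ing [SEP]".split()
print(segment_ids_numpy(tokenized_text))   # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
print(segment_ids_python(tokenized_text))  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
```

Both functions make a single pass over the tokens. The NumPy version avoids the Python-level loop, which tends to matter only for long sequences.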