python pytorch tensor bert-language-model fairseq

Generating segment labels for a Tensor given a value indicating segment boundaries


Does anyone know of a way to generate a 'segment label' for a Tensor, given a unique value that represents segment boundaries within the Tensor?

For example, given a 1D input tensor where the value 1 represents a segment boundary,

x = torch.Tensor([5, 4, 1, 3, 6, 2])

the resulting segment label Tensor should have the same shape with values representing the two segments:

segment_label = torch.Tensor([1, 1, 1, 2, 2, 2])

Likewise, for a batch of inputs, e.g. batch size = 3,

x = torch.Tensor([
    [5, 4, 1, 3, 6, 2],
    [9, 4, 5, 1, 8, 10],
    [10, 1, 5, 4, 8, 9]
    ])

the resulting segment label Tensor (using 1 as the segment separator) should look something like this:

segment_label = torch.Tensor([
    [1, 1, 1, 2, 2, 2],
    [1, 1, 1, 1, 2, 2],
    [1, 1, 2, 2, 2, 2]
    ])

Context: I'm currently working with Fairseq's Transformer implementation in PyTorch for a seq2seq NLP task. I am looking for a way to incorporate BERT-like segment embeddings into the Transformer during the encoder's forward pass, rather than modifying an existing dataset used for translation tasks, such as language_pair_dataset.

Thanks in advance!


Solution

  • You can use torch.cumsum to do the trick:

    mask = (x == 1).to(x)  # 1 at boundary positions, 0 elsewhere (same dtype and device as x)
    segment_label = mask.cumsum(dim=-1) - mask + 1
    

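    This yields the desired segment_label. The cumulative sum counts how many boundaries have been seen up to each position; subtracting mask keeps each boundary token in the segment it closes, and adding 1 makes the labels start at 1.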
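
    To check the approach against the batched example from the question, here is a minimal sketch. The helper name segment_labels and its sep parameter are made up for illustration and are not part of PyTorch or Fairseq; it also assumes integer token IDs, so it builds the mask with .long() instead of .to(x).

```python
import torch


def segment_labels(x: torch.Tensor, sep: int = 1) -> torch.Tensor:
    """Hypothetical helper: 1-based segment index for every position.

    The separator token is counted as the last token of the segment it closes.
    """
    mask = (x == sep).long()  # 1 where the separator occurs, 0 elsewhere
    # cumsum counts separators seen so far (inclusive); subtracting mask
    # delays the increment until the position after each separator.
    return mask.cumsum(dim=-1) - mask + 1


batch = torch.tensor([
    [5, 4, 1, 3, 6, 2],
    [9, 4, 5, 1, 8, 10],
    [10, 1, 5, 4, 8, 9],
])
labels = segment_labels(batch)
print(labels)
```

    Because the cumulative sum runs along dim=-1, the same function handles both the 1D and the batched case. If the labels are meant to index an embedding table such as torch.nn.Embedding, dropping the final + 1 gives zero-based indices.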