Tags: nlp, huggingface-transformers, bert-language-model

Fine-tuning BERT with deterministic masking instead of random masking


I want to fine-tune BERT on a specific dataset. My problem is that I do not want to mask tokens of my training dataset randomly; I have already chosen which tokens I want to mask (for certain reasons).

To do so, I created a dataset with two columns: text, in which some tokens have been replaced with [MASK] (I am aware that some words may be tokenised into more than one token, and I took care of that), and label, which contains the whole, unmasked text.
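
For concreteness, a minimal sketch of what such a dataset could look like (the use of the datasets library and the example rows here are only illustrative assumptions):

    from datasets import Dataset

    # hypothetical rows for illustration; the number of [MASK] tokens per
    # hidden word has to match how that word is tokenised (as noted above)
    dataset = Dataset.from_dict({
        "text": [
            "the invoice total is [MASK] dollars",
            "the patient received [MASK] [MASK] mg of the drug",
        ],
        "label": [
            "the invoice total is 42 dollars",
            "the patient received 12.5 mg of the drug",
        ],
    })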

Now I want to fine-tune a BERT model (say, bert-base-uncased) using Hugging Face's transformers library, but I do not want to use DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.2), where the masking is done randomly and I can only control the probability. What can I do?


Solution

  • This is what I did to solve my problem. I created a custom data collator class and changed the masking logic to do what I needed (mask one of the numerical spans in the input). A sketch of how to plug it into a Trainer follows after the code.

    from itertools import groupby
    from operator import itemgetter

    import torch
    from transformers import DataCollatorForLanguageModeling


    class CustomDataCollator(DataCollatorForLanguageModeling):
    
        mlm: bool = True
        return_tensors: str = "pt"
    
        def __post_init__(self):
            if self.mlm and self.tokenizer.mask_token is None:
                raise ValueError(
                    "This tokenizer does not have a mask token which is necessary "
                    "for masked language modeling. You should pass `mlm=False` to "
                    "train on causal language modeling instead."
                )
    
        def torch_mask_tokens(self, inputs, special_tokens_mask):
            """
            Prepare masked tokens inputs/labels for masked language modeling.
            NOTE: `special_tokens_mask` is kept as an argument to match the
            parent signature and avoid an error; it is not used here.
            """
    
            # labels is a (batch_size, seq_len) tensor holding the original
            # token ids; seq_len includes the two special tokens ([CLS], [SEP])
            labels = inputs.clone()
    
            batch_size = inputs.size(0)
            # seq_len = inputs.size(1)
            # in each seq, find the indices of the tokens that represent digits
            # (ids of the tokens "0"-"9" in the bert-base-uncased vocabulary)
            dig_ids = [1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023]
            dig_idx = torch.zeros_like(labels)
            for dig_id in dig_ids:
                dig_idx += (labels == dig_id)
            dig_idx = dig_idx.bool()
            # in each seq, find the spans of Trues using `find_spans` function
            spans = []
            for i in range(batch_size):
                spans.append(find_spans(dig_idx[i].tolist()))
            masked_indices = torch.zeros_like(labels)
            # spans is a list of lists of tuples
            # in each tuple, the first element is the start index
            # and the second element is the length
            # in each child list, choose a random tuple
            for i in range(batch_size):
                if len(spans[i]) > 0:
                    idx = torch.randint(0, len(spans[i]), (1,))
                    start, length = spans[i][idx[0]]
                    masked_indices[i, start:start + length] = 1
                else:
                    print("No digit found in the sequence!")
            masked_indices = masked_indices.bool()
    
            # We only compute loss on masked tokens
            labels[~masked_indices] = -100
    
            # change the input's masked_indices to self.tokenizer.mask_token
            inputs[masked_indices] = self.tokenizer.mask_token_id
    
            return inputs, labels
    
    def find_spans(lst):
        """Return a (start_index, length) tuple for each run of truthy values in lst."""
        spans = []
        for k, g in groupby(enumerate(lst), key=itemgetter(1)):
            if k:
                glist = list(g)
                spans.append((glist[0][0], len(glist)))
    
        return spans
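
To use this at fine-tuning time, the custom collator simply replaces the default one when building the Trainer. The snippet below is only a rough sketch: it assumes the training split has already been tokenised into input_ids / attention_mask (here called tokenized_dataset, a placeholder name), and the output directory and hyperparameters are placeholders as well.

    from transformers import (
        AutoModelForMaskedLM,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

    # the dataclass __init__ is inherited from DataCollatorForLanguageModeling
    data_collator = CustomDataCollator(tokenizer=tokenizer, mlm=True)

    training_args = TrainingArguments(
        output_dir="bert-deterministic-mlm",  # placeholder
        num_train_epochs=3,
        per_device_train_batch_size=16,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,  # pre-tokenised split (assumed to exist)
        data_collator=data_collator,
    )
    trainer.train()

As a quick sanity check of the helper, find_spans([0, 1, 1, 0, 1]) returns [(1, 2), (4, 1)], i.e. a (start index, length) pair for every run of truthy values, which is exactly what torch_mask_tokens samples a span from.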