word-embedding huggingface-transformers huggingface-tokenizers

How to encode multiple sentences using transformers.BertTokenizer?


I would like to create a minibatch by encoding multiple sentences using transformers.BertTokenizer. It seems to work for a single sentence. How can I make it work for several sentences?

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# tokenizing a single sentence seems to work
tokenizer.encode('this is the first sentence')
>>> [2023, 2003, 1996, 2034, 6251]

# tokenize two sentences
tokenizer.encode(['this is the first sentence', 'another sentence'])
>>> [100, 100] # expecting 7 tokens

Solution

  • transformers >= 4.0.0:
    Use the tokenizer's __call__ method. It returns a dictionary containing the input_ids, token_type_ids, and attention_mask, each as a list with one entry per input sentence:

    tokenizer(['this is the first sentence', 'another sentence'])
    

    Output:

    {'input_ids': [[101, 2023, 2003, 1996, 2034, 6251, 102], [101, 2178, 6251, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]]}
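
    Note that the two lists have different lengths, so they cannot be stacked into a tensor as-is. To get an actual minibatch, __call__ also accepts padding, truncation, and return_tensors (a minimal sketch, assuming PyTorch is installed):

    batch = tokenizer(
        ['this is the first sentence', 'another sentence'],
        padding=True,         # pad shorter sentences to the longest in the batch
        truncation=True,      # truncate anything beyond the model's maximum length
        return_tensors='pt',  # return PyTorch tensors instead of Python lists
    )
    batch['input_ids'].shape
    >>> torch.Size([2, 7])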
    

  • transformers < 4.0.0:
    Use tokenizer.batch_encode_plus. It returns the same dictionary containing the input_ids, token_type_ids, and attention_mask as lists, one per input sentence:

    tokenizer.batch_encode_plus(['this is the first sentence', 'another sentence'])
    

    Output:

    {'input_ids': [[101, 2023, 2003, 1996, 2034, 6251, 102], [101, 2178, 6251, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]]}
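
    For reference, the question's [100, 100] comes from encode interpreting a list of strings as one already-tokenized sequence: each whole sentence is looked up as a single token, is not in the vocabulary, and becomes [UNK] (id 100). The ids 101 and 102 in the outputs above are the [CLS] and [SEP] special tokens that both methods add by default. Both can be checked with convert_ids_to_tokens:

    tokenizer.convert_ids_to_tokens([100, 101, 2178, 6251, 102])
    >>> ['[UNK]', '[CLS]', 'another', 'sentence', '[SEP]']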
    

    Applies to both __call__ and batch_encode_plus:
    If you only want to generate the input_ids, set return_token_type_ids and return_attention_mask to False:

    tokenizer.batch_encode_plus(['this is the first sentence', 'another sentence'], return_token_type_ids=False, return_attention_mask=False)
    

    Output:

    {'input_ids': [[101, 2023, 2003, 1996, 2034, 6251, 102], [101, 2178, 6251, 102]]}
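
    The same keyword arguments work with __call__ as well:

    tokenizer(['this is the first sentence', 'another sentence'], return_token_type_ids=False, return_attention_mask=False)

    Output:

    {'input_ids': [[101, 2023, 2003, 1996, 2034, 6251, 102], [101, 2178, 6251, 102]]}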