I am learning about BERT, which only handles texts of up to 512 tokens, and came across this answer, which says that truncating text in the middle (as opposed to at the start or at the end) may work well for BERT. I wonder whether there is any library to do that type of truncation, because as far as I understand, one word can consist of multiple BERT tokens, so I cannot simply take the middle 512 words. Thanks in advance
The post references a paper which found that keeping the first 128 tokens and the last 382 tokens (not counting the CLS and SEP tokens) works well.
For tokenization, you can use the BertTokenizer from HuggingFace's Transformers library to tokenize the full string and then trim out everything besides the first 129 and last 383 tokens: 129 because we include the initial CLS token, and 383 because we include the ending SEP token (129 + 383 = 512).
Code:
# Generate sample text: every two-letter combination "aa ab ac ... zz"
from string import ascii_lowercase
sample_text = ""
for c1 in ascii_lowercase:
    for c2 in ascii_lowercase:
        sample_text += f"{c1}{c2} "

# Get tokenizer
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Perform tokenization (the tokenizer warns that the sequence is longer
# than 512 tokens; that is expected, since we trim it ourselves below)
tokenized = tokenizer(sample_text)

# Trim every field to the first 129 (CLS + 128) and last 383 (382 + SEP) tokens
if len(tokenized["input_ids"]) > 512:
    for k, v in tokenized.items():
        tokenized[k] = v[:129] + v[-383:]

# Verify result
print(tokenizer.decode(tokenized["input_ids"]))
print(len(tokenized["input_ids"]))
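If you want to reuse this, the trimming step can be wrapped in a small helper. This is just a sketch under my assumptions: `head_tail_truncate` is a hypothetical name, and the plain dict below stands in for the `BatchEncoding` the tokenizer returns, which is also dict-like, so the same function works on real tokenizer output.

```python
def head_tail_truncate(encoding, head=129, tail=383):
    """Trim every field of a tokenizer output to its first `head` and last
    `tail` entries (129 = CLS + first 128 tokens, 383 = last 382 tokens + SEP,
    matching the paper's head+tail heuristic)."""
    if len(encoding["input_ids"]) <= head + tail:
        return encoding  # already fits within the 512-token limit
    for key, values in encoding.items():
        encoding[key] = values[:head] + values[-tail:]
    return encoding

# Hypothetical stand-in for a BatchEncoding with 600 tokens
fake_encoding = {
    "input_ids": list(range(600)),
    "attention_mask": [1] * 600,
}
trimmed = head_tail_truncate(fake_encoding)
print(len(trimmed["input_ids"]))  # 512
```

Note that this mutates the encoding in place, like the snippet above; copy the fields first if you need to keep the untruncated version around.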