python, huggingface-transformers, huggingface-tokenizers, huggingface-datasets

How to drop sentences that are too long in Huggingface?


I'm going through the Huggingface tutorial, and it appears that the library supports automatic truncation: sentences that are too long are cut down based on a maximum length or other criteria.

How can I remove sentences on the same grounds (too long, based on a maximum length, etc.) instead of truncating them? That is, if a sentence is too long, drop it.

Example of truncation:

from transformers import AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
sentence_input = 'this is an input'

result = tokenizer(sentence_input, padding=True, truncation=True, return_tensors="pt")
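To make the difference between truncating and dropping concrete, here is a minimal pure-Python sketch of both behaviors. It assumes a hypothetical whitespace `toy_tokenize` function in place of the real subword tokenizer and a hypothetical `MAX_TOKENS` limit of 8, and it ignores special tokens such as `[CLS]`/`[SEP]`, which a real tokenizer counts toward the limit.

```python
from typing import List

# Hypothetical limit; a real BERT tokenizer exposes its own limit as tokenizer.model_max_length.
MAX_TOKENS = 8


def toy_tokenize(sentence: str) -> List[str]:
    # Hypothetical stand-in for a subword tokenizer: plain whitespace split.
    return sentence.split()


def truncate(sentences: List[str], max_tokens: int) -> List[List[str]]:
    # Conceptually what truncation=True does: keep every sentence, cut off the tail.
    return [toy_tokenize(s)[:max_tokens] for s in sentences]


def drop_long(sentences: List[str], max_tokens: int) -> List[List[str]]:
    # What the question asks for: keep only sentences that already fit.
    kept = []
    for s in sentences:
        tokens = toy_tokenize(s)
        if len(tokens) <= max_tokens:
            kept.append(tokens)
    return kept


sentences = ["this is an input", "word " * 20]
print(len(truncate(sentences, MAX_TOKENS)))   # 2 (long sentence cut to 8 tokens)
print(len(drop_long(sentences, MAX_TOKENS)))  # 1 (long sentence removed)
```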

Example of preparing samples in a batch:

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

raw_datasets = load_dataset("glue", "mrpc")
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(example):
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)


tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
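One way to adapt the batch pipeline above is to filter the raw dataset before mapping, tokenizing each pair without truncation only to measure its length. This is a sketch under stated assumptions, not a library API: `make_length_filter` is a hypothetical helper name, and it only assumes that `tokenizer(list_a, list_b)` returns a mapping with one `input_ids` list per example, as Hugging Face tokenizers do.

```python
from typing import Callable, Dict, List


def make_length_filter(
    tokenizer: Callable,
    max_length: int,
    first: str = "sentence1",
    second: str = "sentence2",
) -> Callable[[Dict[str, List[str]]], List[bool]]:
    """Build a batched predicate for Dataset.filter(..., batched=True).

    Hypothetical helper: `tokenizer` is any callable that takes two lists of
    strings and returns a mapping with one "input_ids" list per example.
    """

    def fits(batch: Dict[str, List[str]]) -> List[bool]:
        # Tokenize WITHOUT truncation so the true length of each pair is visible.
        encoded = tokenizer(batch[first], batch[second])
        # True keeps the row, False drops it.
        return [len(ids) <= max_length for ids in encoded["input_ids"]]

    return fits
```

With the names from the question, it could be applied as `raw_datasets = raw_datasets.filter(make_length_filter(tokenizer, tokenizer.model_max_length), batched=True)` before `raw_datasets.map(tokenize_function, batched=True)`. The trade-off is that every pair is tokenized twice; filtering after `map`, as in the answer below, avoids that.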

Solution

  • A filter is all you need:

    import pandas
    from datasets import Dataset
    from transformers import AutoTokenizer
    
    df = pandas.DataFrame([{"sentence1": "bla", "sentence2": "bla"}, {"sentence1": "bla "*600, "sentence2": "bla"}])
    dataset = Dataset.from_pandas(df)
    
    
    checkpoint = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    
    # Do not truncate: over-long samples keep their full length, so they can be filtered out below
    def tokenize_function(example):
        return tokenizer(example["sentence1"], example["sentence2"])
    
    
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    print(len(tokenized_datasets))
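    # Keep only examples whose full (untruncated) length fits the model's limit;
    # model_max_length is read from the tokenizer config (512 for bert-base-uncased)
    tokenized_datasets = tokenized_datasets.filter(lambda example: len(example["input_ids"]) <= tokenizer.model_max_length)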
    print(len(tokenized_datasets))
    
    

    Output:

    Token indices sequence length is longer than the specified maximum sequence length for this model (1205 > 512). Running this sequence through the model will result in indexing errors
    2
    1
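The `lambda` filter in the answer is called once per row. A batched variant is sketched here; `within_limit` is a hypothetical helper name, and the limit is passed through the `fn_kwargs` argument of `Dataset.filter` rather than captured in a closure.

```python
from typing import Dict, List


def within_limit(batch: Dict[str, List[List[int]]], max_length: int) -> List[bool]:
    """Hypothetical batched predicate for Dataset.filter(..., batched=True)."""
    # Each entry of batch["input_ids"] is one tokenized, untruncated example;
    # return one boolean per row (True = keep).
    return [len(ids) <= max_length for ids in batch["input_ids"]]
```

It could be used as `tokenized_datasets.filter(within_limit, batched=True, fn_kwargs={"max_length": tokenizer.model_max_length})`. With `batched=True`, `datasets` passes the function a dict of column lists (1000 rows per call by default) and expects one boolean per row back, which typically cuts Python call overhead on large datasets.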