python · pytorch · torchtext

torchtext BucketIterator minimum padding


I'm trying to use the BucketIterator.splits function in torchtext to load data from CSV files for use in a CNN. Everything works fine unless I get a batch in which the longest sentence is shorter than the largest filter size.

In my example I have filters of sizes 3, 4, and 5, so if the longest sentence doesn't have at least 5 words I get an error. Is there a way to let the BucketIterator dynamically set the padding for batches, but also enforce a minimum padding length?
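
For context, the error comes from the convolution itself: a filter of size 5 can't slide over a sequence of only 4 tokens. A minimal sketch (not my actual model, just to show the failure):

import torch
import torch.nn as nn

embed_dim = 100                           # illustrative embedding size
conv = nn.Conv2d(1, 64, (5, embed_dim))   # a filter of size 5 over the sequence
batch = torch.randn(32, 1, 4, embed_dim)  # longest sentence in this batch: 4 tokens
out = conv(batch)                         # RuntimeError: kernel size can't be greater
                                          # than actual input size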

This is the code I am using for my BucketIterator:

train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)

I'm hoping there is a way to set a minimum length through the sort_key, or something like that.

I tried this but it doesn't work:

FILTER_SIZES = [3,4,5]
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text) if len(x.text) >= FILTER_SIZES[-1] else FILTER_SIZES[-1], batch_size=batch_size, repeat=False, device=device) 

Solution

  • I looked through the torchtext source code to better understand what the sort_key was doing, and saw why my original idea wouldn't work: the sort_key only controls how examples are sorted and grouped into batches, while the pad length is taken from the longest example actually in each batch, so clamping the sort_key value can't make the padding any longer.

    I'm not sure if it is the best solution, but I have come up with one that works. I created a tokenizer function that pads the text if it is shorter than the longest filter length, then created the BucketIterator from there.

    import torch
    import spacy
    from torchtext.data import Field, BucketIterator

    FILTER_SIZES = [3, 4, 5]
    spacy_en = spacy.load('en')

    def tokenizer(text):
        # tokenize with spaCy, then pad short examples up to the largest
        # filter size so every example has at least FILTER_SIZES[-1] tokens
        token = [t.text for t in spacy_en.tokenizer(text)]
        if len(token) < FILTER_SIZES[-1]:
            for i in range(0, FILTER_SIZES[-1] - len(token)):
                token.append('<PAD>')
        return token

    TEXT = Field(sequential=True, tokenize=tokenizer, lower=True, tensor_type=torch.cuda.LongTensor)

    train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)
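
    With this tokenizer in place, every example is at least FILTER_SIZES[-1] tokens long before padding, so every batch is as well. A quick sanity check (just an illustration using the same names as above; batch.text is (seq_len, batch_size) since batch_first is left at its default):

    for batch in train_iter:
        assert batch.text.size(0) >= FILTER_SIZES[-1]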