python · spacy · torchtext

ReversibleField fails when using spacy custom tokenizer


ReversibleField works well without spacy

When using tokenize=None in the ReversibleField, everything works fine:

from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator, ReversibleField
import spacy

SRC = ReversibleField(tokenize=None,
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True,
            batch_first= True)

TRG = ReversibleField(tokenize=None,
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True,
            batch_first= True)
train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), 
                                                    fields = (SRC, TRG))
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

device = 'cuda:2'

BATCH_SIZE = 3

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, 
    device = device)

batch = next(iter(train_iterator))
TRG.reverse(batch.trg)

output >>>
['a group of kids playing with tires.',
 'seven construction workers working on a building.',
 'a man is performing with fire sticks before a crowd outside.']

ReversibleField fails when using spacy

However, when I try to use spacy as my tokenizer, reverse gives back each sentence as one long string with all of the spaces stripped out, which doesn't make sense to me.

spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    """
    Tokenizes German text from a string into a list of strings (tokens) and reverses it
    """
    return [tok.text for tok in spacy_de.tokenizer(text)][::-1]

def tokenize_en(text):
    """
    Tokenizes English text from a string into a list of strings (tokens)
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

SRC = ReversibleField(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True,
            batch_first= True)

TRG = ReversibleField(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True,
            batch_first= True)

train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), 
                                                    fields = (SRC, TRG))

SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, 
    device = device)

batch = next(iter(train_iterator))
TRG.reverse(batch.trg)

output >>>
['agroupofkidsplayingwithtires.',
 'sevenconstructionworkersworkingonabuilding.',
 'amanisperformingwithfiresticksbeforeacrowdoutside.']

What is wrong here? How do I correctly convert tokens back to strings when using spacy?


Solution

  • There is an obvious error in the ReversibleField definition:

    class ReversibleField(Field):
        def __init__(self, **kwargs):
            warnings.warn('{} class will be retired in the 0.8.0 release and moved to torchtext.legacy. Please see 0.7.0 release notes for further information.'.format(self.__class__.__name__), UserWarning)
            if kwargs.get('tokenize') is list:
                self.use_revtok = False
            else:
                self.use_revtok = True

        ...

        def reverse(self, batch):
            if self.use_revtok:
                try:
                    import revtok
                except ImportError:
                    print("Please install revtok.")
                    raise
            ...
            if self.use_revtok:
                return [revtok.detokenize(ex) for ex in batch]
            return [''.join(ex) for ex in batch]


    You see, unless you pass the tokenize kwarg as the list type itself, use_revtok is set to True, and reverse then always pushes your tokens through revtok.detokenize. That is also why the tokenize=None case works: ReversibleField silently substitutes revtok's own tokenizer, whose tokens carry whitespace markers that detokenize knows how to undo. Spacy tokens carry no such markers, so detokenize glues them together without spaces. The non-revtok fallback, ''.join over the tokens, can only reconstruct the sentence if the tokens themselves include the whitespace.
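
    A minimal illustration of the failure mode (plain Python, no torchtext involved): spacy tokens are bare text with no whitespace information, so joining them reproduces exactly the squashed strings you observed. revtok.detokenize does effectively the same thing here, because the markers it relies on are absent:

    # Bare spacy-style tokens: text only, no whitespace markers.
    tokens = ['a', 'group', 'of', 'kids', 'playing', 'with', 'tires', '.']
    print(''.join(tokens))  # agroupofkidsplayingwithtires.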

    1. Comment out the two revtok lines at the end of reverse above, i.e. the final if self.use_revtok: / return [revtok.detokenize(ex) for ex in batch] pair (the class definition is located in /home/USER/anaconda3/envs/ENV_NAME/lib/python3.7/site-packages/torchtext-0.8.0a0+db31b5d-py3.7-linux-x86_64.egg/torchtext/data/field.py, lines 408-409), so that reverse falls through to the ''.join fallback
    2. Change your tokenizer so that it also emits the whitespace, as in the code block below

    and you're good to go (a less invasive alternative is sketched right after this list).
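
    If you would rather not patch the installed package, a less invasive alternative is possible: use_revtok is a plain instance attribute (see the excerpt above), so you can flip it off per field after construction. This is a sketch of that idea, not a documented API:

    # Sketch: disable revtok detokenization per field instead of editing
    # field.py. With use_revtok = False, reverse() falls through to the
    # ''.join fallback, which works once the tokenizer also emits whitespace.
    SRC.use_revtok = False
    TRG.use_revtok = False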

    Proof:

    from torchtext.datasets import Multi30k
    from torchtext.data import Field, BucketIterator, ReversibleField
    import spacy
    
    # python -m spacy download en_core_web_sm
    # python -m spacy download de_core_news_sm
    
    nlp_en = spacy.load("en_core_web_sm")
    nlp_de = spacy.load("de_core_news_sm")
    
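    # Emit each token followed by the whitespace that trailed it in the
    # original text, so that a plain join reconstructs the string exactly.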
    def tokenize_de(text):
        return [el for els in [(tok.text, tok.whitespace_) for tok in nlp_de(text)] for el in els]
    
    def tokenize_en(text):
        return [el for els in [(tok.text, tok.whitespace_) for tok in nlp_en(text)] for el in els]
    
    
    SRC = ReversibleField(tokenize = tokenize_de,
                            init_token = '<sos>', 
                            eos_token = '<eos>',
                            unk_token='<unk>',
                            lower = True,
                            batch_first= True)
    
    TRG = ReversibleField(tokenize = tokenize_en,
                            init_token = '<sos>',
                            eos_token = '<eos>', 
                            unk_token='<unk>',
                            lower = True,
                            batch_first= True)
    
    train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), 
                                                        fields = (SRC, TRG))
    
    SRC.build_vocab(train_data, min_freq = 3)
    TRG.build_vocab(train_data, min_freq = 3)
    
    train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
        (train_data, valid_data, test_data), 
        batch_size = 3, device="cuda:0")
    
    batch = next(iter(train_iterator))
    TRG.reverse(batch.trg)
    
    output >>>
    ['asian people wearing helmet waiting to buy food.',
     'a mother stands in a kitchen holding a small baby.',
     'a person performing a <unk> bicycle jump over dirt ramps.']
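
    As a quick sanity check (a hypothetical snippet reusing the tokenize_en defined above), the flattened (token, whitespace) tokenizer round-trips any input exactly, which is what makes the join-based reverse lossless:

    # Round-trip check: interleaving tokens with their trailing whitespace
    # means joining them reconstructs the original string character for character.
    s = "A mother stands in a kitchen holding a small baby."
    assert ''.join(tokenize_en(s)) == s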