Tags: python, huggingface-transformers, huggingface-tokenizers, gpt-2

How to replace the tokenize() and pad_sequences() functions from transformers?


I have the following imports:

import torch, csv, transformers, random
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim
import pandas as pd
from transformers import GPT2Tokenizer, GPT2LMHeadModel, tokenize, pad_squences

And I'm getting this error:

ImportError                               Traceback (most recent call last)
<ipython-input-35-e04c63220105> in <module>
      4 import torch.optim as optim
      5 import pandas as pd
----> 6 from transformers import GPT2Tokenizer, GPT2LMHeadModel, tokenize, pad_squences

ImportError: cannot import name 'tokenize' from 'transformers' (/usr/local/lib/python3.8/dist-packages/transformers/__init__.py)

This is how I am using the tokenize() and pad_sequences() functions:

class RephraseDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        query, rephrases = self.data[index]
        tokenized_query = tokenizer.encode(query, add_special_tokens=True)
        # tokenized_query = tokenize(self.tokenizer, query)
        padded_query = tokenized_query + [tokenizer.pad_token_id] * (max_length - len(tokenized_query))
        # padded_query = pad_sequences(self.tokenizer, tokenized_query, max_length=128)
        tokenized_rephrases = [tokenize(self.tokenizer, r) for r in rephrases]
        padded_rephrases = [pad_sequences(self.tokenizer, r, max_length=128) for r in tokenized_rephrases]
        return padded_query, padded_rephrases

# Create the dataset
dataset = RephraseDataset(data, tokenizer)

# Create a dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

How can I fix this problem? I couldn't find anything in the docs. What version should I roll transformers back to?


Solution

  • [EDIT]

    This happened because the installed transformers version is too old for this. Update transformers with pip install -U transformers
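If the import still fails after upgrading (as far as I can tell, transformers does not export top-level tokenize or pad_sequences functions), one option is to define small helpers with the same call shape the question uses. The following is only a minimal sketch: the helper names and signatures are copied from the calls in the question and are not part of the library. It assumes the tokenizer object exposes encode() and pad_token_id, which Hugging Face tokenizers do.

```python
def tokenize(tokenizer, text):
    """Encode text into a list of token ids via the tokenizer's own encode().

    Hypothetical helper: mirrors the tokenize(self.tokenizer, r) call in the question.
    """
    return tokenizer.encode(text, add_special_tokens=True)


def pad_sequences(tokenizer, token_ids, max_length=128):
    """Truncate or right-pad a list of token ids to exactly max_length.

    Hypothetical helper: mirrors pad_sequences(self.tokenizer, r, max_length=128).
    """
    if tokenizer.pad_token_id is None:
        # GPT-2's tokenizer ships without a pad token, so one must be assigned first.
        raise ValueError(
            "tokenizer has no pad token; for GPT-2, "
            "set tokenizer.pad_token = tokenizer.eos_token"
        )
    token_ids = list(token_ids)[:max_length]  # truncate overly long sequences
    padding = [tokenizer.pad_token_id] * (max_length - len(token_ids))
    return token_ids + padding
```

With these defined, the tokenize and pad_sequences names can be dropped from the transformers import line and the Dataset code can stay as written. For GPT-2, set tokenizer.pad_token = tokenizer.eos_token after loading the tokenizer, since it has no pad token by default. Alternatively, calling the tokenizer directly, tokenizer(text, padding="max_length", truncation=True, max_length=128)["input_ids"], should do both steps in one call.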