python · huggingface-transformers · transformer-model

How to resolve Transformer model DistilBert error: got an unexpected keyword argument 'special_tokens_mask'


I am using:

  • Apple Mac M1
  • OS: macOS Monterey
  • Python 3.10.4

I am trying to implement vector search with DistilBERT and Weaviate by following this tutorial.

Below is the code setup:

import nltk
import os
import random
import time
import torch
import weaviate
from transformers import AutoModel, AutoTokenizer
from nltk.tokenize import sent_tokenize

torch.set_grad_enabled(False)

# update to use a different model if desired
MODEL_NAME = "distilbert-base-uncased"
model = AutoModel.from_pretrained(MODEL_NAME)
model.to('cuda') # remove if working without GPUs
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# initialize nltk (for tokenizing sentences)
nltk.download('punkt')

# initialize weaviate client for importing and searching
client = weaviate.Client("http://localhost:8080")

def get_post_filenames(limit_objects=100):
    file_names = []
    i=0
    for root, dirs, files in os.walk("./data/20news-bydate-test"):
        for filename in files:
            path = os.path.join(root, filename)
            file_names += [path]
        
    random.shuffle(file_names)
    limit_objects = min(len(file_names), limit_objects)
      
    file_names = file_names[:limit_objects]

    return file_names

def read_posts(filenames=[]):
    posts = []
    for filename in filenames:
        f = open(filename, encoding="utf-8", errors='ignore')
        post = f.read()
        
        # strip the headers (the first occurrence of two newlines)
        post = post[post.find('\n\n'):]
        
        # remove posts with less than 10 words to remove some of the noise
        if len(post.split(' ')) < 10:
               continue
        
        post = post.replace('\n', ' ').replace('\t', ' ').strip()
        if len(post) > 1000:
            post = post[:1000]
        posts += [post]

    return posts       


def text2vec(text):
    tokens_pt = tokenizer(text, padding=True, truncation=True, max_length=500, add_special_tokens = True, return_tensors="pt")
    tokens_pt.to('cuda') # remove if working without GPUs
    outputs = model(**tokens_pt)
    return outputs[0].mean(0).mean(0).detach()

def vectorize_posts(posts=[]):
    post_vectors=[]
    before=time.time()
    for i, post in enumerate(posts):
        vec=text2vec(sent_tokenize(post))
        post_vectors += [vec]
        if i % 100 == 0 and i != 0:
            print("So far {} objects vectorized in {}s".format(i, time.time()-before))
    after=time.time()
    
    print("Vectorized {} items in {}s".format(len(posts), after-before))
    
    return post_vectors

def init_weaviate_schema():
    # a simple schema containing just a single class for our posts
    schema = {
        "classes": [{
                "class": "Post",
                "vectorizer": "none", # explicitly tell Weaviate not to vectorize anything, we are providing the vectors ourselves through our BERT model
                "properties": [{
                    "name": "content",
                    "dataType": ["text"],
                }]
        }]
    }

    # cleanup from previous runs
    client.schema.delete_all()

    client.schema.create(schema)

def import_posts_with_vectors(posts, vectors, batchsize=256):
    batch = weaviate.ObjectsBatchRequest()

    for i, post in enumerate(posts):
        props = {
            "content": post,
        }
        batch.add(props, "Post", vector=vectors[i])
        
        # when either batch size is reached or we are at the last object
        if (i !=0 and i % batchsize == 0) or i == len(posts) - 1:
            # send off the batch
            client.batch.create(batch)
            
            # and reset for the next batch
            batch = weaviate.ObjectsBatchRequest() 
    

def search(query="", limit=3):
    before = time.time()
    vec = text2vec(query)
    vec_took = time.time() - before

    before = time.time()
    near_vec = {"vector": vec.tolist()}
    res = client \
        .query.get("Post", ["content", "_additional {certainty}"]) \
        .with_near_vector(near_vec) \
        .with_limit(limit) \
        .do()
    search_took = time.time() - before

    print("\nQuery \"{}\" with {} results took {:.3f}s ({:.3f}s to vectorize and {:.3f}s to search)" \
          .format(query, limit, vec_took+search_took, vec_took, search_took))
    for post in res["data"]["Get"]["Post"]:
        print("{:.4f}: {}".format(post["_additional"]["certainty"], post["content"]))
        print('---')

# run everything
init_weaviate_schema()
posts = read_posts(get_post_filenames(4000))
vectors = vectorize_posts(posts)
import_posts_with_vectors(posts, vectors)

search("the best camera lens", 1)
search("which software do i need to view jpeg files", 1)
search("windows vs mac", 1)

The function below triggers the errors:

def text2vec(text):
    # tokens_pt = tokenizer(text, padding=True, truncation=True, max_length=500, add_special_tokens = True, return_tensors="pt")
    tokens_pt = tokenizer.encode_plus(text, add_special_tokens=True, truncation=True, padding="max_length", return_attention_mask=True, return_tensors="pt")

    tokens_pt.to('cuda') # remove if working without GPUs
    outputs = model(**tokens_pt)
    return outputs[0].mean(0).mean(0).detach()

Error 1:

    tokens_pt.to('cuda') # remove if working without GPUs
AttributeError: 'dict' object has no attribute 'to'

When I comment out the GPU line

#tokens_pt.to('cuda')

and run the code, I get this error:

Error 2:

    outputs = model(**tokens_pt)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/py310a/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: DistilBertModel.forward() got an unexpected keyword argument 'special_tokens_mask'

What is causing these errors and how can I fix them?


Solution

  • I was not able to reproduce your errors in my environment (Ubuntu), but from what I can see, I'd suggest trying to add the return_special_tokens_mask=False parameter:

    tokens_pt = tokenizer.encode_plus(
        text, 
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        return_attention_mask=True,
        return_tensors="pt",
        return_special_tokens_mask=False
    )
    

    If that fails, try removing the key explicitly:

    tokens_pt.pop("special_tokens_mask")
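
    For context, here is a minimal sketch of what the full text2vec could look like with both suggestions applied. This is an assumption on my side (I could not test it on Apple Silicon), and since an M1 Mac has no CUDA device, the .to('cuda') call is simply dropped so everything runs on CPU:

    def text2vec(text):
        # tokenize without requesting the special tokens mask
        tokens_pt = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
            return_special_tokens_mask=False
        )
        # fallback: drop the key in case the tokenizer still returns it
        tokens_pt.pop("special_tokens_mask", None)
        outputs = model(**tokens_pt)
        return outputs[0].mean(0).mean(0).detach()

    With that in place, the model call should only receive input_ids and attention_mask, which DistilBertModel.forward() accepts.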