I have a function that loads fine when I run it directly, but is very slow to load when I import it from a helper.py file. I'm not sure what's causing the slow loading.

helper_file.py:
import pandas as pd
from transformers import BertTokenizer

def embed_answers(ans, length):
    sentence_embeddings = []
    embeddings = BertTokenizer.from_pretrained('...')
    sentence_embeddings.extend(embeddings.encode(ans, max_length=length, padding='max_length'))
    return sentence_embeddings

def get_dataset(vec_type):
    vec_dict = {"large": 1000, "medium": 500, "small": 150}
    if vec_type.lower() not in vec_dict:
        raise Exception("Invalid vector type!")
    df = pd.read_hdf('...', mode='r')
    vec_length = vec_dict[vec_type]
    df['embeddings_col'] = df['answer'].apply(embed_answers, length=vec_length)
    return df
When I import and call get_dataset from a main.py file, it crashes without loading fully. But running the function directly from main.py is fine.

Not sure what the issue is; I'd appreciate any ideas, thanks!
I don't know if it will solve your problem, but loading the tokenizer every time embed_answers is called is a waste of resources (time and memory). Try to take advantage of vectorization instead: tokenize the whole column in one call.
import pandas as pd
from transformers import BertTokenizer

def embed_answers(ans, length):
    embeddings = BertTokenizer.from_pretrained('...')
    inputs = embeddings(ans, max_length=length, padding='max_length')
    return inputs['input_ids']

def get_dataset(vec_type):
    vec_dict = {"large": 1000, "medium": 500, "small": 150}
    if vec_type.lower() not in vec_dict:
        raise Exception("Invalid vector type!")
    df = pd.read_hdf('...', mode='r')
    vec_length = vec_dict[vec_type]
    df['embeddings_col'] = embed_answers(df['answer'].tolist(), length=vec_length)
    return df
Demo:
import pandas as pd
from transformers import BertTokenizer

def embed_answers(ans, length):
    embeddings = BertTokenizer.from_pretrained('bert-base-cased')
    inputs = embeddings(ans, max_length=length, padding='max_length')
    return inputs['input_ids']

data = ["Hello, my dog is cute", "my cat is a daemon"]
df = pd.DataFrame({'answer': data})
df['embeddings_col'] = embed_answers(df['answer'].tolist(), length=10)
Output:
>>> df
answer embeddings_col
0 Hello, my dog is cute [101, 8667, 117, 1139, 3676, 1110, 10509, 102, 0, 0]
1 my cat is a daemon [101, 1139, 5855, 1110, 170, 5358, 25027, 102, 0, 0]
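If get_dataset may be called more than once per process, you can go a step further and cache the tokenizer at module level so it is built only once. A minimal sketch of that pattern using functools.lru_cache; the load_tokenizer body here is a hypothetical stand-in (in the real code it would return BertTokenizer.from_pretrained(name)):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def load_tokenizer(name):
    # Stand-in for BertTokenizer.from_pretrained(name): the body runs once
    # per distinct name; later calls return the cached object instantly.
    print(f"loading {name}")
    return object()

tok_a = load_tokenizer("bert-base-cased")  # prints "loading bert-base-cased"
tok_b = load_tokenizer("bert-base-cased")  # cache hit, no print
assert tok_a is tok_b  # the very same instance is reused
```

Inside embed_answers you would then call load_tokenizer('...') instead of BertTokenizer.from_pretrained('...'), and the expensive load happens only on the first call.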