Tags: python, pandas, dataframe, apply

imported function from another file is slow


I have a function like the one below that loads fine when I run it directly, but not when I import it from a separate helper_file.py.

I'm not sure what's causing the slow loading.

helper_file.py:

import pandas as pd
from transformers import BertTokenizer

def embed_answers(ans, length):
    sentence_embeddings = []
    embeddings = BertTokenizer.from_pretrained('...')
    sentence_embeddings.extend(embeddings.encode(ans, max_length=length, padding='max_length'))
    return sentence_embeddings

def get_dataset(vec_type):
    vec_dict = {"large": 1000, "medium": 500, "small": 150}
    if vec_type.lower() not in vec_dict:
        raise Exception("Invalid vector type!")
    df = pd.read_hdf('...', mode='r')
    vec_length = vec_dict[vec_type.lower()]
    df['embeddings_col'] = df['answer'].apply(embed_answers, length=vec_length)
    return df

When I import get_dataset from helper_file.py and call it in main.py, it crashes before it finishes loading. But running the same function defined directly in main.py works fine.
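
main.py is basically just importing and calling the function, something like this (the vector size argument here is only an example):

main.py:

from helper_file import get_dataset

df = get_dataset("small")
print(df.head())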

I'm not sure what the issue is; I'd appreciate any ideas, thanks!


Solution

  • I don't know if it will solve your problem, but loading the tokenizer each time you call embed_answers is a waste of resources (time and memory). Try to take advantage of vectorization instead, and see the sketch after the demo output for loading the tokenizer only once.

    import pandas as pd
    from transformers import BertTokenizer
    
    def embed_answers(ans, length): 
        # The tokenizer is still created here, but it now runs only once
        # per dataset because the whole column is passed in a single call.
        embeddings = BertTokenizer.from_pretrained('...')
        inputs = embeddings(ans, max_length=length, padding='max_length') 
        return inputs['input_ids']
    
    def get_dataset(vec_type): 
        vec_dict = {"large": 1000, "medium": 500, "small": 150} 
        if vec_type.lower() not in vec_dict: 
            raise Exception("Invalid vector type!")
        df = pd.read_hdf('...', mode='r')
        vec_length = vec_dict[vec_type.lower()]
        # Tokenize all answers in one batched call instead of a row-by-row apply.
        df['embeddings_col'] = embed_answers(df['answer'].tolist(), length=vec_length) 
        return df
    

    Demo:

    import pandas as pd
    from transformers import BertTokenizer
    
    def embed_answers(ans, length): 
        embeddings = BertTokenizer.from_pretrained('bert-base-cased')
        inputs = embeddings(ans, max_length=length, padding='max_length') 
        return inputs['input_ids']
    
    data = ["Hello, my dog is cute", "my cat is a daemon"]
    df = pd.DataFrame({'answer': data})
    df['embeddings_col'] = embed_answers(df['answer'].tolist(), length=10)
    

    Output:

    >>> df
                      answer                                        embeddings_col
    0  Hello, my dog is cute  [101, 8667, 117, 1139, 3676, 1110, 10509, 102, 0, 0]
    1     my cat is a daemon  [101, 1139, 5855, 1110, 170, 5358, 25027, 102, 0, 0]
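
    If the repeated tokenizer load turns out to be the bottleneck, you can also hoist it to module level so it is created only once per process. This is just a sketch, assuming the same 'bert-base-cased' checkpoint as the demo; substitute whatever checkpoint you actually use:

    import pandas as pd
    from transformers import BertTokenizer
    
    # Loaded once at import time and reused by every call
    # (assumption: 'bert-base-cased' stands in for the real checkpoint).
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    
    def embed_answers(ans, length):
        # Batch-tokenize the whole list of answers with the cached tokenizer.
        inputs = tokenizer(ans, max_length=length, padding='max_length')
        return inputs['input_ids']
    
    df = pd.DataFrame({'answer': ["Hello, my dog is cute", "my cat is a daemon"]})
    df['embeddings_col'] = embed_answers(df['answer'].tolist(), length=10)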