AutoModelForCausalLM for extracting text embeddings


I have an application that uses AutoModelForCausalLM to answer questions. I need to use the same model to extract embeddings from text. I know I could use SentenceTransformer, but that would mean loading the model weights twice. How can I use AutoModelForCausalLM to extract embeddings from text?


Solution

  • Warning: As mentioned in the comments, you need to check whether the produced sentence embeddings are meaningful. This is necessary because the model you are using wasn't trained to produce sentence embeddings (check this StackOverflow answer for further information).

    Putting that aside, the following code shows one way to retrieve sentence embeddings from databricks/dolly-v2-3b. It uses weighted mean pooling because your model is a decoder with left-to-right (causal) attention. The idea is that tokens at the end of the sentence should contribute more than tokens at the beginning: the hidden state of a later token is contextualized with all the tokens before it, while the hidden states of the first tokens have seen far less context. Each token is therefore weighted by its position (1, 2, 3, ...), and padding tokens get a weight of 0.

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    
    model_id = "databricks/dolly-v2-3b"
    
    t = AutoTokenizer.from_pretrained(model_id)
    m = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    m.eval()
    
    
    texts = [
        "this is a test",
        "this is another test case with a different length",
    ]
    t_input = t(texts, padding=True, truncation=True, return_tensors="pt")
    
    
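    # Run the model without gradients and keep the final hidden states,
    # shape (batch_size, sequence_length, hidden_size).
    with torch.no_grad():
        last_hidden_state = m(**t_input, output_hidden_states=True).hidden_states[-1]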
    
    
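    # Weight each token by its 1-based position; padding tokens get weight 0.
    # Note: this assumes the tokenizer pads on the right, as in the output below.
    weights_for_non_padding = t_input.attention_mask * torch.arange(start=1, end=last_hidden_state.shape[1] + 1).unsqueeze(0)
    
    # Weighted sum of the token states divided by the sum of the weights.
    # Despite its name, num_of_none_padding_tokens holds the weight sum, not a token count.
    sum_embeddings = torch.sum(last_hidden_state * weights_for_non_padding.unsqueeze(-1), dim=1)
    num_of_none_padding_tokens = torch.sum(weights_for_non_padding, dim=-1).unsqueeze(-1)
    sentence_embeddings = sum_embeddings / num_of_none_padding_tokens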
    
    print(t_input.input_ids)
    print(weights_for_non_padding)
    print(num_of_none_padding_tokens)
    print(sentence_embeddings.shape)
    

    Output:

    tensor([[2520,  310,  247, 1071,    0,    0,    0,    0,    0],
            [2520,  310, 1529, 1071, 1083,  342,  247, 1027, 2978]])
    tensor([[1, 2, 3, 4, 0, 0, 0, 0, 0],
            [1, 2, 3, 4, 5, 6, 7, 8, 9]])
    tensor([[10],
            [45]])
    torch.Size([2, 2560])
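
If you want to reuse the pooling step, it can be wrapped in a small function. The following is only a sketch under a few assumptions: `weighted_mean_pool` is a hypothetical helper name (not part of transformers), the tensors are random stand-ins for real hidden states so that it runs without downloading a model, and the position weights are derived from the attention mask with a cumulative sum instead of `torch.arange`, which should also cover tokenizers that pad on the left. It upcasts to float32 so that large position weights are not rounded in half precision.

```python
import torch
import torch.nn.functional as F


def weighted_mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Position-weighted mean over the sequence dimension (hypothetical helper).

    hidden:         (batch, seq_len, dim) token representations
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    """
    hidden = hidden.float()
    mask = attention_mask.to(torch.float32)
    # Count real tokens from the left: 1, 2, 3, ... and 0 on padding.
    # Unlike a plain arange, this does not depend on the padding side.
    weights = torch.cumsum(mask, dim=1) * mask
    summed = (hidden * weights.unsqueeze(-1)).sum(dim=1)
    return summed / weights.sum(dim=1, keepdim=True)


# Toy input: random "hidden states" for 2 sequences, 5 positions, 8 dimensions.
torch.manual_seed(0)
hidden = torch.randn(2, 5, 8)
right_padded = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
pooled = weighted_mean_pool(hidden, right_padded)

# The first sequence again, but with its two padding positions moved to the left.
left_hidden = torch.cat([hidden[:1, 3:], hidden[:1, :3]], dim=1)
left_padded = torch.tensor([[0, 0, 1, 1, 1]])
pooled_left = weighted_mean_pool(left_hidden, left_padded)

# Cosine similarity between the two pooled vectors. On random input the value
# carries no meaning; it only shows how two embeddings could be compared.
similarity = F.cosine_similarity(pooled[0], pooled[1], dim=0)
print(pooled.shape, similarity.item())
```

With the real model you would pass `last_hidden_state` and `t_input.attention_mask` instead of the toy tensors. Comparing the cosine similarity of a few sentence pairs whose relatedness you already know is one simple way to start checking whether the embeddings are meaningful for your use case.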