Tags: nlp, huggingface-transformers, nearest-neighbor

ALBERT: first word associations


I have a list of words (e.g., "apple," "banana," "mango") and would like to use ALBERT (https://huggingface.co/albert-base-v2) to identify the 10 words that are most strongly associated with each word in my list. In simple terms: "Hey ALBERT, what's the first word that comes to your mind when hearing apple/banana/mango?"

My first idea was to use a prompt like "apple is related to [MASK].", but some of the top predictions are quite weird or are not proper words, like 'evalle'.
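
Concretely, that attempt looked roughly like this (a sketch of what I ran; the exact call may have differed):

    from transformers import pipeline

    unmasker = pipeline('fill-mask', model='albert-base-v2')
    # The pipeline returns the top 5 predictions for the masked position by default:
    for prediction in unmasker("apple is related to [MASK]."):
        print(prediction["score"], prediction["token_str"])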

My second idea was to use a k-nearest neighbors approach. However, I don't know how to implement that with the Hugging Face transformers library. Is it possible to do that without fine-tuning? Or do you have another idea?
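
Roughly, what I have in mind for the k-NN idea is something like this untested sketch, searching the raw word-embedding matrix of albert-base-v2 by cosine similarity (I'm not sure this is the right way to do it):

    import torch
    from transformers import AlbertTokenizer, AlbertModel

    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    albert = AlbertModel.from_pretrained('albert-base-v2')

    emb = albert.embeddings.word_embeddings.weight            # (30000, 128) embedding matrix
    ids = tokenizer("apple", add_special_tokens=False)["input_ids"]
    query = emb[ids[0]]                                       # first sub-word token of "apple"
    sims = torch.nn.functional.cosine_similarity(query.unsqueeze(0), emb)
    top = torch.topk(sims, k=11)                              # k+1 because the word itself ranks first
    for score, idx in zip(top.values, top.indices):
        print(score.item(), tokenizer.decode([idx.item()]))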


Solution

  • All you need for that is to use ALBERT's embedding and decoding layers, and Transformers gives you easy access to both:

    import torch
    from transformers import pipeline
    from transformers import AlbertTokenizer

    # The fill-mask pipeline loads the full masked-LM model (AlbertForMaskedLM):
    unmasker = pipeline('fill-mask', model='albert-base-v2')
    model = unmasker.model
    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    

    You can simply print model to see the list of layers that make up the model. We want model.albert.embeddings.word_embeddings, which transforms a word (actually a word token) into a vector, here one with 128 values. We also want model.predictions.decoder, which does the opposite operation by outputting a score for each of the 30,000 tokens in the embedding vocabulary.
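
    For instance, printing those two sub-modules shows their shapes (the reprs in the comments are what I'd expect for albert-base-v2):

    print(model.albert.embeddings.word_embeddings)  # Embedding(30000, 128, padding_idx=0)
    print(model.predictions.decoder)                # Linear(in_features=128, out_features=30000, bias=True)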

    inputs = tokenizer("apple", return_tensors='pt')["input_ids"]
    # The tokenizer adds a start and end token, but we only want the middle one:
    apple_embed = model.albert.embeddings.word_embeddings(inputs)[0,1,:]
    # Take the ten highest-scoring tokens for this embedding:
    topk = torch.topk(model.predictions.decoder(apple_embed), k=10)
    
    # Output the result:
    for prob, ind in zip(topk.values, topk.indices):
        print(prob.item(), tokenizer.decode(ind))
    

    And you get a rather telling look at the importance we give to natural products compared with trademarks:

    0.11230556666851044 apple
    0.07886222004890442 apple
    0.05940583348274231 atari
    0.05769592523574829 macintosh
    0.057240456342697144 blackberry
    0.054983582347631454 amazon
    0.05450776219367981 iphone
    0.05447978526353836 mango
    0.05426642298698425 itunes
    0.05226457118988037 raspberry
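
    To answer the original question for the whole list, you can wrap the same steps in a small helper and loop over your words (a sketch reusing the model and tokenizer defined above; note that a word split into several sub-word tokens would only contribute its first piece here):

    def top_associations(word, k=10):
        ids = tokenizer(word, return_tensors='pt')["input_ids"]
        # Skip the [CLS]/[SEP] special tokens and take the first word piece:
        embed = model.albert.embeddings.word_embeddings(ids)[0, 1, :]
        topk = torch.topk(model.predictions.decoder(embed), k=k)
        return [(score.item(), tokenizer.decode(ind)) for score, ind in zip(topk.values, topk.indices)]

    for word in ["apple", "banana", "mango"]:
        print(word, top_associations(word))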