I have a list of words (e.g., "apple," "banana," "mango") and would like to use ALBERT (https://huggingface.co/albert-base-v2) to identify the 10 words that are most strongly associated with each word in my list. In simple terms: "Hey ALBERT, what's the first word that comes to your mind when hearing apple/banana/mango?"
My first idea was to use a prompt like "apple is related to [MASK].", but some of the top predictions are quite weird or not proper words, like 'evalle'.
My second idea was to use a k-nearest neighbors approach. However, I don't know how to implement that with the Hugging Face Transformers library. Is it possible to do that without fine-tuning? Do you have another idea?
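The k-nearest-neighbors idea can be tried without any fine-tuning by comparing one row of the word-embedding matrix against every other row. Below is a minimal sketch, assuming cosine similarity as the distance measure (a choice made here for illustration, not something ALBERT prescribes) and a plain PyTorch tensor standing in for ALBERT's 30000 × 128 embedding matrix. `nearest_tokens` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def nearest_tokens(embedding_matrix: torch.Tensor, query_id: int, k: int = 10) -> list[int]:
    """Return ids of the k rows most similar (cosine) to row `query_id`.

    The query row itself is excluded from the result.
    """
    # Normalise every row so that a dot product equals cosine similarity.
    normed = F.normalize(embedding_matrix, dim=1)
    sims = normed @ normed[query_id]
    # Push the query's own similarity to -inf so it is never returned.
    sims[query_id] = float("-inf")
    k = min(k, embedding_matrix.size(0) - 1)
    return torch.topk(sims, k=k).indices.tolist()


if __name__ == "__main__":
    # Toy 4-token "vocabulary" with 2-dimensional embeddings.
    toy = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
    print(nearest_tokens(toy, query_id=0, k=2))  # [1, 2]
```

With the real model you would pass `model.albert.embeddings.word_embeddings.weight` and the id of the word's token. Unlike the decoder-based ranking, cosine similarity ignores vector length and any bias term, so the two rankings can differ.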
All you need for that is ALBERT's embedding and decoding layers, and Transformers gives you easy access to both:
```python
import torch
from transformers import pipeline, AlbertTokenizer

unmasker = pipeline('fill-mask', model='albert-base-v2')
model = unmasker.model  # the AlbertForMaskedLM behind the pipeline
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
```
You can just print `model` to see the list of layers it is made of. We want `model.albert.embeddings.word_embeddings`, which maps a word (actually a word token) to a vector, here one with 128 values. We also want `model.predictions.decoder`, which does the opposite: it outputs a score (a raw logit, not a probability; apply a softmax if you need probabilities) for each of the 30,000 tokens in the embedding vocabulary.
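Why the decoder acts as the "opposite" of the embedding layer: to my understanding, Transformers ties ALBERT's decoder weights to the input word embeddings by default, so `decoder(e)` computes `E @ e + b`, i.e. the dot product of `e` with every vocabulary embedding plus a bias. The toy sketch below reproduces that with hypothetical small sizes (not ALBERT's real 30000 × 128) to show that taking the decoder's top-k is a dot-product nearest-neighbour search, and that a softmax is what turns the scores into probabilities.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, dim = 50, 8  # toy sizes, not ALBERT's

embeddings = nn.Embedding(vocab_size, dim)
decoder = nn.Linear(dim, vocab_size)
# Tie the weights: the decoder reuses the (vocab_size x dim) embedding matrix.
decoder.weight = embeddings.weight

with torch.no_grad():
    e = embeddings(torch.tensor(3))  # embedding of token id 3
    logits = decoder(e)
    # Same scores by hand: dot product with every embedding, plus the bias.
    manual = embeddings.weight @ e + decoder.bias
    probs = torch.softmax(logits, dim=-1)  # raw scores -> probabilities

print(torch.allclose(logits, manual))  # True
print(round(probs.sum().item(), 4))    # 1.0
```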
```python
inputs = tokenizer("apple", return_tensors='pt')["input_ids"]
with torch.no_grad():
    # The tokenizer adds a start ([CLS]) and end ([SEP]) token; we only want the middle one.
    # This assumes the word is a single token in ALBERT's vocabulary.
    apple_embed = model.albert.embeddings.word_embeddings(inputs)[0, 1, :]
    # Take the ten highest scores for this embedding:
    topk = torch.topk(model.predictions.decoder(apple_embed), k=10)
# Output the result:
for score, ind in zip(topk.values, topk.indices):
    print(score.item(), tokenizer.decode(ind))
```
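The snippet above handles one word that is a single token. Below is a minimal sketch extending it to a list of words, under two assumptions made here rather than taken from the answer: a word the tokenizer splits into several pieces is represented by the mean of its piece embeddings (a heuristic, not something ALBERT was trained for), and the word's own token(s) are removed from the ranking. Scores go through a softmax to become probabilities. `associated_token_ids` is a hypothetical helper; it takes the layers as arguments so it can also be tried on small stand-in modules.

```python
import torch
from torch import nn


@torch.no_grad()
def associated_token_ids(embeddings: nn.Embedding, decoder: nn.Linear,
                         word_ids: list[int], k: int = 10) -> list[tuple[int, float]]:
    """Rank vocabulary tokens by the decoder's score for one word's embedding.

    `word_ids` are the token ids of ONE word, without special tokens.
    Several pieces are averaged (heuristic assumption).
    Returns (token_id, probability) pairs, excluding the word's own pieces.
    """
    vec = embeddings(torch.tensor(word_ids)).mean(dim=0)
    probs = torch.softmax(decoder(vec), dim=-1)
    probs[word_ids] = 0.0  # drop the query token(s) themselves
    top = torch.topk(probs, k=k)
    return list(zip(top.indices.tolist(), top.values.tolist()))


if __name__ == "__main__":
    # Toy stand-ins with tied weights, instead of ALBERT's real layers.
    torch.manual_seed(0)
    emb = nn.Embedding(20, 4)
    dec = nn.Linear(4, 20)
    dec.weight = emb.weight
    for tok_id, p in associated_token_ids(emb, dec, [2, 7], k=3):
        print(tok_id, round(p, 4))
```

With the objects from the answer, you would call `associated_token_ids(model.albert.embeddings.word_embeddings, model.predictions.decoder, tokenizer(word, add_special_tokens=False)["input_ids"])` for each `word` in your list and run `tokenizer.decode` on each returned id.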
And you get a critical look at the importance we give to natural products compared to trademarks (the numbers are the decoder's raw scores):
```
0.11230556666851044 apple
0.07886222004890442 apple
0.05940583348274231 atari
0.05769592523574829 macintosh
0.057240456342697144 blackberry
0.054983582347631454 amazon
0.05450776219367981 iphone
0.05447978526353836 mango
0.05426642298698425 itunes
0.05226457118988037 raspberry
```
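Because `torch.topk` returns distinct indices, the two `apple` lines must come from two different vocabulary entries that decode to the same text (probably the SentencePiece piece with the `▁` word-boundary marker and the one without it, though that is a guess). If you only want distinct words, here is a minimal sketch that keeps the first (highest-scoring) occurrence of each decoded string. `dedupe_ranked` is a hypothetical helper, and the scores in the example are rounded from the output above.

```python
def dedupe_ranked(pairs, k=10):
    """Keep the first occurrence of each decoded word from a ranked list.

    `pairs` is an iterable of (score, word) tuples already sorted by score,
    e.g. collected from the topk loop. Words are compared after stripping
    whitespace and lower-casing.
    """
    seen = set()
    result = []
    for score, word in pairs:
        key = word.strip().lower()
        if key in seen:
            continue  # a duplicate of a higher-ranked entry
        seen.add(key)
        result.append((score, word.strip()))
        if len(result) == k:
            break
    return result


ranked = [(0.1123, "apple"), (0.0789, "apple"), (0.0594, "atari")]
print(dedupe_ranked(ranked))  # [(0.1123, 'apple'), (0.0594, 'atari')]
```

Since duplicates are dropped, ask `torch.topk` for somewhat more than ten candidates (say `k=20`) before deduplicating down to ten.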