Tags: pytorch, nlp, bert-language-model, fine-tuning

How do I fine-tune BERT's self-attention mechanism?


My goal is to fine-tune BERT's self-attention so that I can see to what extent two random sentences in a document (with positional encoding) rely on each other contextually.

Many explanations and articles that I have seen discuss the implementation of self-attention, but they do not mention how to train it further.

Here is what I'm thinking of doing to train BERT's self-attention:

  1. Use some kind of word-to-vector algorithm to vectorize all the words in the article.

  2. Add positional encoding to each sentence [where a sentence is an array of word vectors] using a sinusoidal function.

  3. Build a matrix of each sentence concatenated with every other sentence.

  4. For each sentence-sentence pair, iterate through the words, masking each one in turn. The model must guess the masked word from context, and backpropagation is driven by the accuracy of that guess (a rough sketch of this step is shown after the list).

  5. The finished model should be able to take in an arbitrary sentence-sentence pair and output an attention matrix.
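
To make step 4 concrete, here is a minimal sketch of that masking step, assuming Hugging Face `transformers` with `BertForMaskedLM`; the checkpoint name, the sentence pair, and the 15% masking rate are just illustrative choices, not a definitive recipe:

```python
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.train()

# Hypothetical sentence pair taken from the same document.
sentence_a = "The committee postponed the vote."
sentence_b = "Members could not reach an agreement."

# Encode the pair as [CLS] A [SEP] B [SEP]; BERT adds its own learned
# positional embeddings, so no manual sinusoidal encoding is needed here.
inputs = tokenizer(sentence_a, sentence_b, return_tensors="pt")
labels = inputs["input_ids"].clone()

# Randomly mask ~15% of the tokens (excluding special tokens); the model
# is then trained to recover the original words from the surrounding context.
mask = torch.rand(labels.shape) < 0.15
mask &= labels != tokenizer.cls_token_id
mask &= labels != tokenizer.sep_token_id
inputs["input_ids"][mask] = tokenizer.mask_token_id
labels[~mask] = -100  # the loss is computed only on the masked positions

outputs = model(**inputs, labels=labels)
outputs.loss.backward()  # an optimizer.step() would follow in a real training loop
```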

I'm not sure whether such a method is the right way to fine-tune a self-attention mechanism (or whether this even counts as continued pre-training), or whether BERT is even the best model to train a self-attention function on.

I'm obviously very new to fine-tuning LLMs, so any guidance would be greatly appreciated!


Solution

  • Hugging Face provides a model class that you can use for your task: https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertForNextSentencePrediction. See also the small example provided at that link. You can use the logits in the output to create your attention matrix.
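
    A minimal sketch, assuming the `bert-base-uncased` checkpoint and two placeholder sentences, of running that next-sentence-prediction head on a sentence pair:

```python
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")

sentence_a = "The weather was terrible all week."
sentence_b = "Most flights out of the city were cancelled."

# Encode the pair as [CLS] A [SEP] B [SEP].
inputs = tokenizer(sentence_a, sentence_b, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Index 0 = "B follows A", index 1 = "B is a random sentence".
probs = torch.softmax(outputs.logits, dim=-1)
print(f"P(sentence_b follows sentence_a) = {probs[0, 0].item():.3f}")
```

    If you want the raw self-attention weights rather than the NSP logits, you can also pass `output_attentions=True` to the forward call and read `outputs.attentions` (a tuple with one attention tensor per layer).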

    This is just an inference task, though. Fine-tuning means doing further training on custom data.
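
    If you do want to fine-tune that head on your own data, a minimal sketch (assuming hypothetical labeled sentence pairs, where 0 means sentence B really follows sentence A and 1 means it does not) would look roughly like this:

```python
import torch
from torch.optim import AdamW
from transformers import BertTokenizer, BertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
model.train()

# Hypothetical custom data: (sentence A, sentence B, label).
# 0 = B follows A in the document, 1 = B is a random sentence.
pairs = [
    ("The stock market fell sharply.", "Investors reacted with panic selling.", 0),
    ("The stock market fell sharply.", "My cat sleeps on the windowsill.", 1),
]

optimizer = AdamW(model.parameters(), lr=2e-5)

for sent_a, sent_b, label in pairs:
    inputs = tokenizer(sent_a, sent_b, return_tensors="pt")
    outputs = model(**inputs, labels=torch.tensor([label]))
    outputs.loss.backward()   # cross-entropy over the two NSP classes
    optimizer.step()
    optimizer.zero_grad()
```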