
Is it possible to fine-tune BERT to do retweet prediction?


I want to build a classifier that predicts if user i will retweet tweet j.

The dataset is huge: it contains 160 million tweets. Each tweet comes with some metadata (e.g., whether the retweeter follows the tweet's author).

The text of each tweet is given as an ordered list of BERT token ids. To get the embedding of a tweet, you use these ids directly (so the input is not raw text).

Is it possible to fine-tune BERT to make this prediction? If so, what courses or resources do you recommend for learning how to fine-tune it? (I'm a beginner.)

I should add that the prediction should be a probability.

If it's not possible, I'm thinking of converting the embeddings back to text then using some arbitrary classifier that I'm going to train.


Solution

  • Yes: you can fine-tune BERT and use it for retweet prediction, but predicting whether user i will retweet tweet j requires more architecture than BERT alone.

    Here is an architecture off the top of my head.


    At a high level:

    1. Create a dense vector representation (embedding) of user i (perhaps capturing something about the user's interests, such as sports).
    2. Create an embedding of tweet j.
    3. Combine the two embeddings into a single vector, e.g., by concatenation or the Hadamard product.
    4. Feed this combined vector through a neural network (NN) that performs binary classification: retweet or no retweet.

    Let's go through the architecture item by item.

    To create an embedding of user i, you need a neural network that takes whatever features you have about the user and produces a dense vector. This is the most difficult component of the architecture. It is not my area of expertise, but a quick Google search for "user interest embedding" brings up a research paper on an algorithm called StarSpace. The paper reports that StarSpace can "obtain highly informative user embeddings according to user behaviors", which is what you want.
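As a much simpler stand-in for the user tower (this is a plain MLP, not StarSpace), the sketch below assumes each user is described by a fixed-length vector of numeric features, e.g. age, follower count, or a multi-hot encoding of interests. The class name `UserEncoder`, the layer sizes, and the feature layout are all hypothetical.

```python
import torch
import torch.nn as nn


class UserEncoder(nn.Module):
    """Hypothetical user tower: maps numeric user features to a dense embedding."""

    def __init__(self, num_features: int, embed_dim: int = 64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Linear(128, embed_dim),
        )

    def forward(self, user_features: torch.Tensor) -> torch.Tensor:
        # user_features: (batch, num_features) -> (batch, embed_dim)
        return self.mlp(user_features)


# A batch of 2 users, each described by 10 made-up features.
encoder = UserEncoder(num_features=10, embed_dim=64)
user_emb = encoder(torch.randn(2, 10))
print(user_emb.shape)  # torch.Size([2, 64])
```

If you also have user ids, an `nn.Embedding` over those ids is another common way to learn per-user interests from behavior; whether that helps here depends on your data.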

    To create an embedding of tweet j, you can use any neural network that takes tokens and produces a vector. Before 2018, the usual choice would have been an LSTM or a CNN. However, BERT (as you mentioned in your post) was the state of the art when this was written. BERT takes text (or token ids, which is what you already have) and produces a vector for each token. The input should begin with the prepended [CLS] token, whose output vector is commonly taken as the representation of the whole sentence. This article provides a conceptual overview of the process. This part of the architecture is where you fine-tune BERT. This webpage provides concrete code using PyTorch and the Hugging Face implementation of BERT to do this step (I've gone through the steps and can vouch for it). For further material, search for "BERT single sentence classification".
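Because the tweets in the question are already BERT token ids, they can be passed straight to the model as `input_ids` without converting them back to text. Below is a minimal sketch using Hugging Face `transformers`. It builds a deliberately tiny, randomly initialised `BertConfig` so it runs offline. For real fine-tuning you would instead load pretrained weights with `BertModel.from_pretrained(...)`, using the checkpoint whose vocabulary produced your ids (the question does not say which one). The ids 101 ([CLS]), 102 ([SEP]) and 0 ([PAD]) follow the `bert-base-uncased` vocabulary; the other ids are arbitrary placeholders.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny random config so the example runs without downloading weights.
# For real use: BertModel.from_pretrained(<checkpoint matching your ids>).
config = BertConfig(
    vocab_size=30522,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
bert = BertModel(config)

# Two tweets as BERT ids: [CLS]=101 ... [SEP]=102, padded with [PAD]=0.
input_ids = torch.tensor([
    [101, 2307, 2208, 102, 0],
    [101, 9458, 3185, 2001, 102],
])
attention_mask = (input_ids != 0).long()  # 1 for real tokens, 0 for padding

outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
cls_embedding = outputs.last_hidden_state[:, 0]  # vector at the [CLS] position
print(cls_embedding.shape)  # torch.Size([2, 32])
```

With a real `bert-base` checkpoint the [CLS] vector has 768 dimensions instead of 32.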

    To combine user i and tweet j into one embedding, you have several options. You can simply concatenate them: if user i is an M-dimensional vector and tweet j is an N-dimensional vector, the concatenation is an (M+N)-dimensional vector. Alternatively, you can compute the Hadamard (element-wise) product; in this case, both vectors must have the same dimension.
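Both combination options can be sketched in a few lines of PyTorch, with random tensors standing in for the real embeddings (M = 64 and N = 768 are assumed sizes; 768 is the hidden size of `bert-base`). For the Hadamard product, the sketch adds a learned linear projection to bring the tweet vector down to M dimensions; that projection is one common way to make the sizes match, not something the answer prescribes.

```python
import torch

user_emb = torch.randn(4, 64)    # M = 64
tweet_emb = torch.randn(4, 768)  # N = 768

# Concatenation: (batch, M) and (batch, N) -> (batch, M + N)
combined_concat = torch.cat([user_emb, tweet_emb], dim=-1)
print(combined_concat.shape)  # torch.Size([4, 832])

# Hadamard product needs equal sizes, so project the tweet vector to M dims first.
project = torch.nn.Linear(768, 64)
combined_hadamard = user_emb * project(tweet_emb)
print(combined_hadamard.shape)  # torch.Size([4, 64])
```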

    To make the final retweet/no-retweet classification, build a simple NN that takes the combined vector and outputs a single value. Since this is binary classification, end the NN with a logistic (sigmoid) function, which maps the output into the range (0, 1). You can interpret that output as the probability of a retweet, which is the probability you asked for; a value above 0.5 would predict a retweet. See this webpage for the basics of building a NN for binary classification.
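A minimal sketch of such a classification head in PyTorch. The input size of 832 (64 user + 768 tweet dimensions after concatenation) and the hidden size of 256 are assumptions, and the inputs and labels are random placeholders. During training, `BCEWithLogitsLoss` is fed the raw logit because it applies the sigmoid internally in a numerically stable way; `torch.sigmoid` is applied only when you want the probability itself.

```python
import torch
import torch.nn as nn

combined_dim = 832  # e.g. 64 (user) + 768 (tweet) after concatenation
head = nn.Sequential(
    nn.Linear(combined_dim, 256),
    nn.ReLU(),
    nn.Linear(256, 1),  # one logit per (user, tweet) pair
)

combined = torch.randn(4, combined_dim)
logits = head(combined).squeeze(-1)  # (batch,)
probs = torch.sigmoid(logits)        # retweet probabilities in (0, 1)
predicted = probs > 0.5              # True = predicted retweet

# Training loss: sigmoid + binary cross-entropy, computed from the raw logits.
labels = torch.tensor([1.0, 0.0, 0.0, 1.0])
loss = nn.BCEWithLogitsLoss()(logits, labels)
```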

    To make the whole system work, you need to train it end-to-end: connect all the pieces first and then train them jointly, rather than training each component separately.
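The sketch below wires a simple user tower, a BERT tweet tower and a classifier head into one module, then runs a single training step with one optimizer over all parameters, so the loss gradient reaches BERT as well as the other parts. `RetweetModel`, the layer sizes, the learning rate and the toy inputs are hypothetical. A tiny randomly initialised BERT keeps it runnable offline; you would swap in `BertModel.from_pretrained(...)` for real use.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


class RetweetModel(nn.Module):
    """Hypothetical end-to-end model: user tower + BERT tweet tower + classifier head."""

    def __init__(self, num_user_features: int, bert: BertModel, user_dim: int = 32):
        super().__init__()
        self.user_tower = nn.Sequential(nn.Linear(num_user_features, user_dim), nn.ReLU())
        self.bert = bert
        self.head = nn.Linear(user_dim + bert.config.hidden_size, 1)

    def forward(self, user_features, input_ids, attention_mask):
        user_emb = self.user_tower(user_features)
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        tweet_emb = bert_out.last_hidden_state[:, 0]  # [CLS] vector
        combined = torch.cat([user_emb, tweet_emb], dim=-1)
        return self.head(combined).squeeze(-1)  # one logit per pair


# Tiny random BERT (no pooler, since only the [CLS] hidden state is used).
config = BertConfig(hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = RetweetModel(num_user_features=10, bert=BertModel(config, add_pooling_layer=False))

# One optimizer over *all* parameters, so BERT is fine-tuned jointly.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.BCEWithLogitsLoss()

user_features = torch.randn(2, 10)
input_ids = torch.tensor([[101, 2307, 2208, 102, 0], [101, 9458, 3185, 2001, 102]])
attention_mask = (input_ids != 0).long()
labels = torch.tensor([1.0, 0.0])

model.train()
optimizer.zero_grad()
loss = loss_fn(model(user_features, input_ids, attention_mask), labels)
loss.backward()  # gradients flow into the user tower, BERT, and the head
optimizer.step()
```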

    Your input dataset would look something like this:

    user                          tweet                  retweet?
    ----                          -----                  --------
    20 years old, likes sports    Great game             Y
    30 years old, photographer    Teen movie was good    N 
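In practice the user column would be numeric features and the tweet column would be the BERT ids you already have. The sketch below turns such rows into padded tensors for one batch. The feature encoding ([age, likes_sports, is_photographer]), the ids and the `collate` helper are hypothetical, and it assumes the [PAD] id is 0, as in the `bert-base-uncased` vocabulary.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical rows: (user features [age, likes_sports, is_photographer], tweet ids, retweeted?)
rows = [
    ([20.0, 1.0, 0.0], [101, 2307, 2208, 102], "Y"),
    ([30.0, 0.0, 1.0], [101, 9458, 3185, 2001, 2204, 102], "N"),
]


def collate(batch, pad_id=0):
    """Turn a list of rows into padded tensors for one training step."""
    user_features = torch.tensor([r[0] for r in batch])
    ids = [torch.tensor(r[1]) for r in batch]
    input_ids = pad_sequence(ids, batch_first=True, padding_value=pad_id)
    attention_mask = (input_ids != pad_id).long()  # ignore padded positions
    labels = torch.tensor([1.0 if r[2] == "Y" else 0.0 for r in batch])
    return user_features, input_ids, attention_mask, labels


user_features, input_ids, attention_mask, labels = collate(rows)
print(input_ids.shape)    # torch.Size([2, 6])
print(attention_mask[0])  # tensor([1, 1, 1, 1, 0, 0])
```

A function like this can be passed as `collate_fn` to a PyTorch `DataLoader`, which matters at 160 million tweets because the data has to be streamed in batches.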
    

    If you want a simpler route without user personalization, just leave out the components that create an embedding of user i. You can then use BERT alone to predict whether a tweet is retweeted, regardless of the user, again following the links mentioned above.
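For this tweet-only variant, `transformers` already provides `BertForSequenceClassification`, which puts a classification head on top of BERT. The sketch below uses two classes with a softmax, so column 1 gives the retweet probability. A tiny random config stands in for a pretrained checkpoint, and the ids and labels are placeholders.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# Tiny random model for illustration; use from_pretrained(...) for real fine-tuning.
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
                    intermediate_size=64, num_labels=2)
model = BertForSequenceClassification(config)

input_ids = torch.tensor([[101, 2307, 2208, 102, 0], [101, 9458, 3185, 2001, 102]])
attention_mask = (input_ids != 0).long()
labels = torch.tensor([1, 0])  # 1 = retweeted, 0 = not retweeted

outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
outputs.loss.backward()  # cross-entropy over the two classes

# Probability of the "retweeted" class for each tweet.
retweet_prob = torch.softmax(outputs.logits, dim=-1)[:, 1]
```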