
How does beam search operate on the output of the Transformer?


According to my understanding (please correct me if I'm wrong), beam search is a breadth-first search (BFS) that, at each level of the "graph" of possibilities, only keeps exploring the b most likely options, where b is the beam size.

To score each option (in my case, for an NLP task), we compute the probability of each token given everything that comes before it; the score of a whole candidate sequence is the product of these conditional probabilities.
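Written out, this is the chain-rule factorization of a sequence's probability. The notation is introduced here for illustration: x is the input (source) sequence and y_1, ..., y_T are the generated tokens. Implementations usually sum log-probabilities instead of multiplying probabilities, to avoid numerical underflow on long sequences:

```latex
\log P(y_1, \dots, y_T \mid x) = \sum_{t=1}^{T} \log P(y_t \mid y_1, \dots, y_{t-1}, x)
```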

This makes sense in a recurrent architecture: you run the decoder on each of the b best first tokens to get the probabilities of the second token for each of them, keep the b best two-token sequences, and so on. Eventually you get complete sequences with probabilities and simply pick the one with the highest probability.
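A minimal sketch of the recurrent procedure described above, assuming a hypothetical toy RNN (`ToyRNNDecoder`, random untrained weights, numpy only). The names `rnn_beam_search`, `bos` and `eos` are illustrative, not from any library. Scores are summed log-probabilities, and hypotheses that end in `eos` are set aside while the others keep expanding:

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

class ToyRNNDecoder:
    """Hypothetical toy RNN decoder with random weights (illustration only)."""
    def __init__(self, vocab_size, hidden_size, seed=0):
        rng = np.random.default_rng(seed)
        self.emb = rng.normal(size=(vocab_size, hidden_size))
        self.w = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
        self.out = rng.normal(size=(hidden_size, vocab_size))
        self.hidden_size = hidden_size

    def initial_state(self):
        return np.zeros(self.hidden_size)

    def step(self, state, token):
        # One recurrent step: only the previous state and the last token are needed.
        new_state = np.tanh(self.emb[token] + self.w @ state)
        return new_state, log_softmax(new_state @ self.out)

def rnn_beam_search(decoder, bos, eos, beam_size, max_len):
    # Each hypothesis: (tokens, cumulative log-prob, last RNN state).
    beams = [([bos], 0.0, decoder.initial_state())]
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score, state in beams:
            new_state, logp = decoder.step(state, tokens[-1])
            # Expand each hypothesis with its beam_size most likely next tokens.
            for tok in np.argsort(logp)[::-1][:beam_size]:
                candidates.append((tokens + [int(tok)], score + float(logp[tok]), new_state))
        # Prune: keep only the beam_size best candidates overall.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for cand in candidates[:beam_size]:
            if cand[0][-1] == eos:
                finished.append(cand)
            else:
                beams.append(cand)
        if not beams:
            break
    finished.extend(beams)
    best = max(finished, key=lambda c: c[1])
    return best[0], best[1]
```

Only `state`, a fixed-size vector, is carried from one step to the next; that is the property the question is missing in the Transformer case.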

However, the Transformer architecture has no such recurrence. Its output is a probability distribution over the vocabulary for every position in the sequence, with shape (batch size, max sequence length, vocab size). How do I interpret this output for beam search? I can get the encodings of the input sequence, but since there is no recurrence that feeds the previous output back in as input for decoding the next token, how do I calculate the probabilities of all the possible sequences that stem from the b best tokens?
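To make the output shape concrete, here is a small numpy sketch with random stand-in logits (no real model is involved). It assumes the common setup where the decoder input is shifted right by a start token and self-attention is causally masked, so the distribution at position i is the prediction for the token at position i + 1. Under that assumption, only the last position is needed to extend a prefix:

```python
import numpy as np

def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

batch_size, seq_len, vocab_size = 2, 5, 10
rng = np.random.default_rng(0)

# Stand-in for decoder output on a prefix of length seq_len: random logits.
logits = rng.normal(size=(batch_size, seq_len, vocab_size))
log_probs = log_softmax(logits)  # (batch, seq_len, vocab)

# Under the causal-mask assumption, log_probs[:, i, :] is the distribution over
# the token at position i + 1 given tokens 0..i; the last row predicts the next token.
next_token_log_probs = log_probs[:, -1, :]  # (batch, vocab)

# The b = 3 most likely next tokens for each example in the batch.
top_b = np.argsort(next_token_log_probs, axis=-1)[:, ::-1][:, :3]
```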


Solution

  • Beam search works exactly the same way as with recurrent models. The decoder is not recurrent (it is self-attentive), but it is still autoregressive, i.e., generating a token is conditioned on the previously generated tokens.

    At training time, the self-attention is masked so that each position can only attend to the words to its left, i.e., to the tokens before the one currently being generated. This simulates the setup you have at inference time, when you really do only have the left context (because the right context has not been generated yet).

    The only difference is that the RNN decoder only needs the last RNN state at every beam search step, whereas with the Transformer you always need to keep the entire hypothesis and run self-attention over the entire left context.
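A minimal sketch of both points, assuming a hypothetical one-layer, single-head self-attention decoder (`ToyCausalDecoder`, random untrained weights, numpy only; it omits the encoder, layer norm and feed-forward sublayers of a real Transformer). The causal mask means row i of the output depends only on tokens 0..i, and each beam search step re-runs the decoder over the whole prefix and reads the last row. The function name `transformer_beam_search` is illustrative:

```python
import numpy as np

def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

class ToyCausalDecoder:
    """Hypothetical one-layer, single-head self-attention decoder (illustration only)."""
    def __init__(self, vocab_size, d_model, max_len=32, seed=0):
        rng = np.random.default_rng(seed)
        self.emb = rng.normal(size=(vocab_size, d_model))
        self.pos = rng.normal(size=(max_len, d_model))
        self.wq, self.wk, self.wv = (
            rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(3)
        )
        self.out = rng.normal(size=(d_model, vocab_size))

    def forward(self, tokens):
        # tokens: the full prefix; returns log-probs of shape (len(tokens), vocab).
        n = len(tokens)
        x = self.emb[tokens] + self.pos[:n]
        q, k, v = x @ self.wq, x @ self.wk, x @ self.wv
        scores = q @ k.T / np.sqrt(x.shape[1])
        # Causal mask: position i may only attend to positions <= i.
        scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
        attn = np.exp(log_softmax(scores))
        return log_softmax((x + attn @ v) @ self.out)

def transformer_beam_search(decoder, bos, eos, beam_size, max_len):
    # Each hypothesis: (full prefix, cumulative log-prob). No recurrent state is kept.
    beams = [([bos], 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            # Re-run the decoder over the entire prefix; the last row predicts the next token.
            logp = decoder.forward(tokens)[-1]
            for tok in np.argsort(logp)[::-1][:beam_size]:
                candidates.append((tokens + [int(tok)], score + float(logp[tok])))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for cand in candidates[:beam_size]:
            if cand[0][-1] == eos:
                finished.append(cand)
            else:
                beams.append(cand)
        if not beams:
            break
    finished.extend(beams)
    return max(finished, key=lambda c: c[1])
```

Because of the mask, scoring a finished hypothesis in a single forward pass over the full sequence gives the same log-probabilities as the step-by-step search. Re-running the whole prefix at every step is wasteful; practical implementations usually cache the keys and values of earlier positions so each step only computes the new position, which speeds decoding without changing the result.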