Tags: pytorch, large-language-model, causal-inference, gemma

Understanding the Change in Output Tensor Shape During Causal Inference in the Gemma Model's MLP Block


I am printing the shape of the output tensor of the MLP block during causal inference of the Gemma model for a given input. What I observe is that during the first token generation the shape is (batch_size, input_seq_length, hidden_size), but for subsequent token generations the shape changes to (batch_size, 1, hidden_size). For example, consider a given input sequence of length 5 and a desired output length of 2:

[Image: the printed MLP output shapes: (batch_size, 5, hidden_size) for the first generation step, then (batch_size, 1, hidden_size) for the second]
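
For reference, the shapes can be printed with a forward hook, roughly like this (a minimal sketch assuming the Hugging Face transformers implementation of Gemma, where the decoder layers live under model.model.layers and each layer has an mlp submodule; the checkpoint name is only an example):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "google/gemma-2b"  # example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Print the MLP output shape of the first decoder layer on every forward call.
    def print_mlp_shape(module, inputs, output):
        print("MLP output shape:", tuple(output.shape))

    hook = model.model.layers[0].mlp.register_forward_hook(print_mlp_shape)

    inputs = tokenizer("The quick brown fox jumps", return_tensors="pt")
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=2, do_sample=False)
    hook.remove()
    # First call prints (1, input_seq_length, hidden_size);
    # the second call prints (1, 1, hidden_size).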

Why does this happen? My understanding is that during the first token's inference, the model processes the entire input sequence through the Gemma_Decoder blocks, generating the first output token while obtaining hidden states for every token in the input sequence. However, for subsequent token generations, it only feeds in the last generated token to produce a new one, retrieving information about the previous tokens through the KV cache built up over the course of inference.

I would love to understand this in more depth, so if anyone can provide links to resources, it would be of great help.


Solution

  • The short answer is that the model internals cache previous time-steps to avoid re-computing things unnecessarily. Do a text search for cache in the model's codebase and see what comes up.

    Longer answer:

    Imagine you do autoregressive generation naively. You start with an input of length 5, do a forward pass, and generate a new token. You concatenate the new token to your input sequence, so you now have a sequence of length 6. You start again, doing a full forward pass with an input of length 6, predict a new token, and repeat.
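
    As a sketch, the naive loop looks something like this (greedy decoding, with model standing in for any causal LM that maps token ids to logits of shape (batch, seq_len, vocab); generate_naive is just an illustrative name):

        import torch

        def generate_naive(model, input_ids, num_new_tokens):
            # Re-runs the model over the entire (growing) sequence at every step.
            for _ in range(num_new_tokens):
                logits = model(input_ids)                                 # (batch, seq_len, vocab)
                next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick at the last position
                input_ids = torch.cat([input_ids, next_token], dim=1)     # sequence grows by one token
            return input_ids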

    This is extremely inefficient. If your transformer uses upper-triangular masking (standard for causal language models), the activations at sequence position n depend only on tokens 0, ..., n and never on anything that comes later. This means those activations never change as the sequence grows, so you can compute them once, cache them, and re-use the cached values when computing new tokens.
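
    Concretely, the mask just blocks attention to later positions, so nothing computed at position n can ever see a token that comes after it. A minimal illustration:

        import torch

        T = 5                                                      # sequence length
        scores = torch.randn(T, T)                                 # raw attention scores (query x key)
        future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))         # row n keeps only columns 0..n
        weights = scores.softmax(dim=-1)                           # each row attends only to past + self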

    When computing a new token, you only need to reference previous-token values in the attention operation (the feed-forward and normalization layers do not operate across tokens). So really you only need to cache the key (K) and value (V) activations from previous timesteps, hence the term KV cache.

    When you run inference on a new timestep, you reference the cached KV values from old tokens for the attention steps. Then you update the KV cache with KV values from the new token.
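
    A stripped-down, single-head version of that update (ignoring multiple heads, positional encodings, and masking details; attend_with_cache is just an illustrative name) might look like this:

        import torch

        def attend_with_cache(q_new, k_new, v_new, cache):
            # q_new, k_new, v_new: (batch, 1, head_dim) projections for the newly generated token.
            # cache: dict holding "k" and "v" tensors of shape (batch, past_len, head_dim).
            k = torch.cat([cache["k"], k_new], dim=1) if cache else k_new
            v = torch.cat([cache["v"], v_new], dim=1) if cache else v_new
            cache["k"], cache["v"] = k, v                          # update the cache with the new token's K/V

            scores = q_new @ k.transpose(-2, -1) / k.shape[-1] ** 0.5   # (batch, 1, past_len + 1)
            weights = scores.softmax(dim=-1)
            return weights @ v                                     # (batch, 1, head_dim): output for the new token only

    Note that no causal mask is needed for the new token's query: it is allowed to attend to every cached position plus itself.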

    This means you don't have to run a full forward pass over the ever-growing sequence at every step.
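
    In the Hugging Face transformers API this is what use_cache and past_key_values do, and it is exactly why the MLP block sees a length-1 sequence after the first step. A simplified, greedy sketch of roughly what generate does internally (not the library's actual implementation; generate_with_cache is a made-up helper) would be:

        import torch

        @torch.no_grad()
        def generate_with_cache(model, input_ids, num_new_tokens):
            # Prefill: one full pass over the prompt (MLP sees (batch, input_seq_length, hidden_size)).
            out = model(input_ids=input_ids, use_cache=True)
            generated = input_ids
            for _ in range(num_new_tokens):
                next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=1)
                # Decode: feed only the new token plus the cache
                # (MLP now sees (batch, 1, hidden_size), the shape change in the question).
                out = model(input_ids=next_token,
                            past_key_values=out.past_key_values,
                            use_cache=True)
            return generated

    The first call is usually called the prefill phase and the per-token calls the decode phase; those two phases correspond exactly to the two shapes you printed.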