Tags: machine-learning, nlp, large-language-model, transformer-model, meta-learning

Understanding the results of "Transformers Learn In-Context by Gradient Descent"


I'm trying to implement this paper: https://arxiv.org/pdf/2212.07677

Their code is here: https://github.com/google-research/self-organising-systems/tree/master/transformers_learn_icl_by_gd

I'm struggling to match their experimental results. Specifically, for their simplest GD model (a single layer with a single head and no softmax), they report a constant, low loss of roughly 0.20 on their test data, and I don't understand conceptually why this is the case.
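A minimal sketch of why a single softmax-free head can act like one step of GD, assuming scalar targets, a linear model w.x initialised at w = 0, and the loss L(w) = 1/(2N) sum_i (w.x_i - y_i)^2. The function names and the explicit `lr` argument are hypothetical and not taken from the authors' repository; as I understand the paper's construction, the step size would instead be folded into the attention weight matrices.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def gd_step_prediction(xs, ys, x_q, lr):
    """One GD step from w = 0 on L(w) = 1/(2N) * sum_i (w.x_i - y_i)^2, then predict at x_q."""
    n, d = len(xs), len(x_q)
    # grad L(0) = -(1/N) * sum_i y_i x_i, so w1 = 0 - lr * grad = (lr/N) * sum_i y_i x_i
    w1 = [lr / n * sum(y * x[j] for x, y in zip(xs, ys)) for j in range(d)]
    return dot(w1, x_q)

def linear_attention_prediction(xs, ys, x_q, lr):
    """Single head, no softmax: score_i = x_i.x_q, output = (lr/N) * sum_i score_i * y_i."""
    n = len(xs)
    scores = [dot(x, x_q) for x in xs]  # raw dot products, not normalised by a softmax
    return lr / n * sum(s * y for s, y in zip(scores, ys))

rng = random.Random(0)
d, n = 3, 10
w_true = [rng.gauss(0, 1) for _ in range(d)]
xs = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
ys = [dot(w_true, x) for x in xs]
x_q = [rng.uniform(-1, 1) for _ in range(d)]

p_gd = gd_step_prediction(xs, ys, x_q, lr=1.0)
p_attn = linear_attention_prediction(xs, ys, x_q, lr=1.0)
print(math.isclose(p_gd, p_attn, rel_tol=1e-9, abs_tol=1e-12))  # prints True
```

Both functions compute (lr/N) * sum_i y_i (x_i . x_q) and differ only in the order of summation. With a softmax the scores would be exponentiated and normalised, and the two would no longer coincide.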

As I understand it, this model performs only a single iteration of gradient descent on the in-context data, so why would it reach such a low loss? And why would the loss be constant or nearly constant over training steps? Aren't we training the learning rate in the GD model?
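To see how much a single step can achieve, here is a sketch that evaluates the query loss of one GD step from w = 0 over many random tasks and then picks the step size. The task sampler (w ~ N(0, I), x ~ U(-1, 1)^d, noise-free y = w.x, d = 10, N = 10) is an assumption for illustration and may not match the paper's exact settings; the helper names are hypothetical.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def sample_task(d, n, alpha, rng):
    """Hypothetical task sampler: w ~ N(0, I), x ~ U(-alpha, alpha)^d, y = w.x (no noise)."""
    w = [rng.gauss(0, 1) for _ in range(d)]
    xs = [[rng.uniform(-alpha, alpha) for _ in range(d)] for _ in range(n + 1)]
    ys = [dot(w, x) for x in xs]
    return xs[:n], ys[:n], xs[n], ys[n]  # n context pairs, then one query pair

def step_direction_at_query(xs, ys, x_q):
    """s = g.x_q with g = (1/N) * sum_i y_i x_i, so the one-step prediction is lr * s."""
    n, d = len(xs), len(x_q)
    g = [sum(y * x[j] for x, y in zip(xs, ys)) / n for j in range(d)]
    return dot(g, x_q)

def query_loss(lr, pairs):
    """Mean of 0.5 * (lr * s - y_q)^2 over tasks."""
    return sum(0.5 * (lr * s - y_q) ** 2 for s, y_q in pairs) / len(pairs)

rng = random.Random(0)
tasks = [sample_task(d=10, n=10, alpha=1.0, rng=rng) for _ in range(1000)]
pairs = [(step_direction_at_query(xs, ys, x_q), y_q) for xs, ys, x_q, y_q in tasks]

# The prediction is linear in lr, so the loss is a parabola in lr with a closed-form minimiser.
best_lr = sum(s * y for s, y in pairs) / sum(s * s for s, _ in pairs)
print(f"lr = 0 (predict zero): {query_loss(0.0, pairs):.3f}")
print(f"best lr = {best_lr:.2f}: {query_loss(best_lr, pairs):.3f}")
```

In this baseline the only free quantity is one scalar, and its loss is a parabola in lr. If the GD model's learning rate is fitted by gradient descent rather than in closed form, it is still a one-dimensional quadratic problem that should converge very quickly, which could make its loss curve look nearly flat; that is a guess, not something I verified in the repository.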


Solution

  • What data are you using in your replication? As far as I can tell, the paper does not explicitly state the data parameters used for the particular result you are trying to replicate. In fact, it tests several alpha values for the distributions used in Figure 6, and the loss can plausibly be low even after a single step of GD when alpha is small. If you see the same trends in the relative behavior of the GD and transformer layers, I don't think it is important to match the exact loss values.
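The alpha effect can be illustrated under one specific, hypothetical reading: alpha sets the input range, x ~ U(-alpha, alpha)^d (check the Figure 6 setup for what alpha actually controls). With the same random draws rescaled by alpha and the step size re-tuned for each alpha, the tuned one-step loss scales like alpha^2, simply because the targets y = w.x shrink with alpha. The function name and defaults below are illustrative assumptions.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def tuned_one_step_loss(alpha, d=10, n=10, num_tasks=1000, seed=0):
    """Query loss of one GD step from w = 0 with the best lr, for x ~ U(-alpha, alpha)^d (assumed)."""
    rng = random.Random(seed)  # same seed for every alpha, so only the input scale changes
    pairs = []
    for _ in range(num_tasks):
        w = [rng.gauss(0, 1) for _ in range(d)]
        xs = [[rng.uniform(-alpha, alpha) for _ in range(d)] for _ in range(n + 1)]
        ys = [dot(w, x) for x in xs]
        g = [sum(ys[i] * xs[i][j] for i in range(n)) / n for j in range(d)]
        pairs.append((dot(g, xs[n]), ys[n]))  # (one-step direction at the query, query target)
    best_lr = sum(s * y for s, y in pairs) / sum(s * s for s, _ in pairs)
    return sum(0.5 * (best_lr * s - y) ** 2 for s, y in pairs) / num_tasks

for alpha in (0.25, 0.5, 1.0, 2.0):
    print(f"alpha = {alpha}: loss = {tuned_one_step_loss(alpha):.4f}")
```

Under this assumption the target variance also scales like alpha^2, so the loss relative to it does not change; a low absolute loss at small alpha then says little about how well one GD step actually fits the task.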