jax, flax

Should models be trained using fori_loop?


When optimizing weights and biases of a model, does it make sense to replace:

for _ in range(epochs):
    w, b = step(w, b)

With:

from jax import lax
w, b = lax.fori_loop(0, epochs, lambda i, wb: step(*wb), (w, b))

If I understand correctly, this means that the entire training process can then be a single compiled JAX function (takes in training data, outputs optimized weights and biases).
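The idea can be sketched end to end on a toy problem. This is a minimal sketch, not a recommended training setup. It assumes a hypothetical linear-regression task (fitting y = 3x + 1 with plain gradient descent), and the names `loss_fn`, `step`, `train_python_loop` and `train_fori` are illustrative only. Note that `lax.fori_loop` calls its body as `body(i, carry)`, passing the loop index first. The same update runs once from a Python loop and once as a single jitted function built on `fori_loop`.

```python
import jax
import jax.numpy as jnp
from jax import lax

LEARNING_RATE = 0.1  # hypothetical step size for this toy problem

def loss_fn(w, b, x, y):
    # Mean squared error of the linear model w * x + b.
    return jnp.mean((w * x + b - y) ** 2)

@jax.jit
def step(w, b, x, y):
    # One gradient-descent update on (w, b).
    dw, db = jax.grad(loss_fn, argnums=(0, 1))(w, b, x, y)
    return w - LEARNING_RATE * dw, b - LEARNING_RATE * db

def train_python_loop(w, b, x, y, epochs):
    # Python drives the loop and dispatches one compiled step per iteration.
    for _ in range(epochs):
        w, b = step(w, b, x, y)
    return w, b

@jax.jit
def train_fori(w, b, x, y, epochs):
    # The whole training run is one compiled function: data in, parameters out.
    def body(i, wb):  # fori_loop passes the loop index first
        return step(wb[0], wb[1], x, y)
    return lax.fori_loop(0, epochs, body, (w, b))

x = jnp.linspace(-1.0, 1.0, 32)
y = 3.0 * x + 1.0
w0, b0 = jnp.float32(0.0), jnp.float32(0.0)

w_py, b_py = train_python_loop(w0, b0, x, y, 500)
w_fl, b_fl = train_fori(w0, b0, x, y, 500)
print(float(w_py), float(b_py), float(w_fl), float(b_fl))  # both pairs near (3, 1)
```

Both versions compute the same updates; the `fori_loop` version simply moves the iteration into the compiled program, so `epochs` can even be a traced value without triggering recompilation.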

Is this a standard approach? What are the tradeoffs to consider?


Solution

  • It's fine to train your model with fori_loop, particularly for simple models. It may be slightly faster, since the whole loop runs as one compiled computation without per-step Python dispatch, but don't expect large gains: in general XLA won't fuse operations across different loop iterations. It's also not possible to exit a fori_loop early once the loss reaches some threshold (though you could do that with while_loop if you wish).

    For more complicated models, you will often want to do some sort of I/O at every step (e.g. loading new training data or logging fit parameters). While this is possible within fori_loop via jax.experimental.io_callback, it is less convenient than doing it directly from the host in a Python for loop, so in general users tend to write their training iterations as Python for loops.
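The two options for early stopping can be contrasted on the same hypothetical toy problem. This is a sketch under assumptions, not a reference implementation: `train_until` keeps everything inside `jit` and stops via the `cond_fun` of `lax.while_loop`, while `train_host_loop` lets Python drive the loop, so logging and early exit are ordinary `print` and `break`. All function names, the tolerance, and the logging interval are illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax import lax

LEARNING_RATE = 0.1  # hypothetical step size for this toy problem

def loss_fn(w, b, x, y):
    # Mean squared error of the linear model w * x + b.
    return jnp.mean((w * x + b - y) ** 2)

@jax.jit
def step(w, b, x, y):
    # One gradient-descent update on (w, b).
    dw, db = jax.grad(loss_fn, argnums=(0, 1))(w, b, x, y)
    return w - LEARNING_RATE * dw, b - LEARNING_RATE * db

@jax.jit
def train_until(w, b, x, y, max_epochs, tol):
    # Compiled early stopping: the carry is (step count, w, b, current loss).
    def cond(state):
        i, _, _, loss = state
        return (i < max_epochs) & (loss > tol)

    def body(state):
        i, w, b, _ = state
        w, b = step(w, b, x, y)
        return i + 1, w, b, loss_fn(w, b, x, y)

    init = (jnp.int32(0), w, b, loss_fn(w, b, x, y))
    return lax.while_loop(cond, body, init)

def train_host_loop(w, b, x, y, max_epochs, tol, log_every=50):
    # Host-driven loop: each float() pulls the loss back to the host,
    # which makes logging and early exit plain Python.
    steps = 0
    loss = float(loss_fn(w, b, x, y))
    for _ in range(max_epochs):
        if loss <= tol:
            break  # early exit is an ordinary Python break
        w, b = step(w, b, x, y)
        steps += 1
        loss = float(loss_fn(w, b, x, y))
        if steps % log_every == 0:
            print(f"step {steps}: loss {loss:.3e}")  # host-side logging
    return steps, w, b, loss

x = jnp.linspace(-1.0, 1.0, 32)
y = 3.0 * x + 1.0
w0, b0 = jnp.float32(0.0), jnp.float32(0.0)

n_jit, w_jit, b_jit, loss_jit = train_until(w0, b0, x, y, 1000, 1e-6)
n_host, w_host, b_host, loss_host = train_host_loop(w0, b0, x, y, 1000, 1e-6)
print(int(n_jit), n_host)  # both stop well before 1000 steps
```

The compiled version avoids a device-to-host transfer on every step, but anything it logs has to go through a callback; the host version pays that transfer each iteration in exchange for unrestricted Python inside the loop.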