When optimizing the weights and biases of a model, does it make sense to replace:

```python
for _ in range(epochs):
    w, b = step(w, b)
```

with:

```python
w, b = lax.fori_loop(0, epochs, lambda i, wb: step(*wb), (w, b))
```
If I understand correctly, this means the entire training process can become a single compiled JAX function that takes in training data and outputs the optimized weights and biases.
Is this a standard approach? What are the tradeoffs to consider?
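For concreteness, here's a minimal sketch of what I have in mind (the linear model, loss function, and hyperparameters are made-up stand-ins):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

def loss(w, b, X, y):
    # Illustrative stand-in: mean squared error for a linear model.
    return jnp.mean((X @ w + b - y) ** 2)

@partial(jax.jit, static_argnames="epochs")
def train(X, y, w, b, epochs=1000, lr=1e-2):
    def step(i, wb):
        w, b = wb
        # One full-batch gradient-descent update per loop iteration.
        gw, gb = jax.grad(loss, argnums=(0, 1))(w, b, X, y)
        return w - lr * gw, b - lr * gb
    return lax.fori_loop(0, epochs, step, (w, b))
```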
It's fine to train your model with `fori_loop`, particularly for simple models. It may be slightly faster, but in general XLA won't fuse operations across different loop steps. It's also not possible to return early from a `fori_loop` when you reach a certain loss threshold (though you could do that with `while_loop` if you wish).
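For example, here's a rough sketch of early stopping with `lax.while_loop`; the loss, update rule, and threshold are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax

def train_until(X, y, w, b, max_steps=10_000, tol=1e-4, lr=1e-2):
    def loss(w, b):
        # Hypothetical loss: mean squared error for a linear model.
        return jnp.mean((X @ w + b - y) ** 2)

    def cond(state):
        i, w, b = state
        # Continue while the loss is above tol and we haven't hit max_steps.
        return (loss(w, b) > tol) & (i < max_steps)

    def body(state):
        i, w, b = state
        gw, gb = jax.grad(loss, argnums=(0, 1))(w, b)
        return i + 1, w - lr * gw, b - lr * gb

    _, w, b = lax.while_loop(cond, body, (0, w, b))
    return w, b
```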
For more complicated models, you will often want to do some sort of I/O at every step (e.g. loading new training data, logging fit parameters, etc.). While this is possible within `fori_loop` via `jax.experimental.io_callback`, it is somewhat less convenient than doing it directly from the host within a Python `for` loop, so in general users tend to use `for` loops for their training iterations.
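If you do want everything in one compiled function, here's a rough sketch of host-side logging from within `fori_loop` via `io_callback`; the logging function and where it's called are just for illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback

def log_loss(i, loss):
    # Runs on the host, so arbitrary Python I/O is fine here.
    print(f"step {int(i)}: loss = {float(loss):.4f}")

@jax.jit
def train(X, y, w, b, lr=1e-2):
    def loss_fn(w, b):
        return jnp.mean((X @ w + b - y) ** 2)

    def step(i, wb):
        w, b = wb
        gw, gb = jax.grad(loss_fn, argnums=(0, 1))(w, b)
        w, b = w - lr * gw, b - lr * gb
        # Send values back to the host; result_shape_dtypes is None
        # because log_loss returns nothing. ordered=True keeps the
        # prints in step order.
        io_callback(log_loss, None, i, loss_fn(w, b), ordered=True)
        return w, b

    return lax.fori_loop(0, 100, step, (w, b))
```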