I'm trying to run this simple introduction to score-based generative modeling. The code uses flax.optim, which seems to have been superseded by optax in the meantime (https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html).
I've made a copy of the Colab code with the changes I think are needed. The only thing I'm unsure about is what should replace the line optimizer = flax.jax_utils.replicate(optimizer).
Now, in the training section, I get the error
pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
at the line loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state). The pmapped function is created by return jax.pmap(step_fn, axis_name='device') in the "Define the loss function" section.
How can I fix this error? I've googled it but have no idea what's going wrong here.
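This happens because you are passing a scalar (rank-0) argument to a pmapped function. By default, pmap maps every argument along its leading axis (axis 0), so each argument needs rank at least 1. For example: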
import jax
func = lambda x: x ** 2
pfunc = jax.pmap(func)
pfunc(1.0)
# ValueError: pmap was requested to map its argument along axis 0, which implies
# that its rank should be at least 1, but is only 0 (its shape is ())
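The same check applies to every leaf of a pytree argument, not just to bare scalars. In a training step like the one in the question, a plausible (unconfirmed) source of the error is a state object with a 0-d leaf, such as a step counter. This sketch reproduces the error with plain JAX using a hypothetical state dict; the exact message text may vary between JAX versions.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# Hypothetical training state: "w" already has a leading device axis,
# but "count" is a 0-d scalar, like a step counter in an optimizer state.
state = {"w": jnp.ones((n, 3)), "count": jnp.array(0)}

def step(s):
    return {"w": s["w"] * 2.0, "count": s["count"] + 1}

raised = False
try:
    jax.pmap(step)(state)
except ValueError as err:
    # pmap checks every leaf; the 0-d "count" leaf fails the rank check.
    raised = True
    print(err)
```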
If you want to operate on a scalar, call the function directly without wrapping it in pmap:
func(1.0)
# 1.0
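If only some arguments are per-device and a scalar should be shared by all devices, another option is to leave that argument unmapped by giving it `None` in pmap's `in_axes`. A minimal sketch, where the `scale` argument is a hypothetical example:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# Map x along axis 0 (one element per device); broadcast `scale` unchanged to every device.
scaled_square = jax.pmap(lambda x, scale: scale * x ** 2, in_axes=(0, None))

x = jnp.arange(n, dtype=jnp.float32)
out = scaled_square(x, 3.0)
print(out)  # shape (n,): 3 * x**2 computed per device
```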
Alternatively, if you want to use pmap, pass an array whose leading dimension matches the number of devices:
num_devices = len(jax.devices())
x = jax.numpy.arange(num_devices)
pfunc(x)
# with 8 devices: Array([ 0, 1, 4, 9, 16, 25, 36, 49], dtype=int32)
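For the training loop in the question, this likely means giving every pmapped argument a leading device axis before the loop: replicating params and opt_state (the role flax.jax_utils.replicate(optimizer) played in the original code) and splitting step_rng into one key per device. The sketch below is an assumption about how that code is structured: `step_fn`, `params` and `opt_state` here are hypothetical toy stand-ins for the colab's model and optax state, and replication is done with `jnp.broadcast_to` over the pytree. Calling `flax.jax_utils.replicate(params)` and `flax.jax_utils.replicate(opt_state)` should serve the same purpose if you prefer Flax's helper.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def replicate(tree):
    # Add a leading device axis of size n to every leaf (scalars become shape (n,)).
    return jax.tree_util.tree_map(
        lambda leaf: jnp.broadcast_to(leaf, (n,) + jnp.shape(leaf)), tree)

# Hypothetical stand-ins for the colab's params and optax optimizer state.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
opt_state = {"count": jnp.array(0)}

def step_fn(rng, x, params, opt_state):
    # Toy update: per-device loss, averaged across devices via the named axis.
    noise = jax.random.normal(rng, x.shape)
    loss = jnp.mean((x + noise) @ params["w"] + params["b"])
    loss = jax.lax.pmean(loss, axis_name="device")
    new_params = jax.tree_util.tree_map(lambda p: p - 0.1 * loss, params)
    return loss, new_params, {"count": opt_state["count"] + 1}

train_step_fn = jax.pmap(step_fn, axis_name="device")

# Replicate once, before the training loop.
params = replicate(params)
opt_state = replicate(opt_state)

rng = jax.random.PRNGKey(0)
rng, step_rng = jax.random.split(rng)
step_rng = jax.random.split(step_rng, n)  # one key per device
x = jnp.ones((n, 4, 3))                   # per-device batch: leading axis = n

loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state)
print(loss.shape, opt_state["count"])
```

Every leaf, including the scalar counter, now has a leading axis of size n, so pmap can map it across devices.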