Tags: python, neural-network, jax, pmap

jax: How do we solve the error: pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0?


I'm trying to run this simple introduction to score-based generative modeling. The code uses flax.optim, which seems to have been replaced by optax in the meantime (https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html).

I've made a copy of the colab code with the changes I think are needed (the only part I'm unsure about is how to replace optimizer = flax.jax_utils.replicate(optimizer)).

Now, in the training section, I get the error

pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

at the line loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state). This obviously comes from the return jax.pmap(step_fn, axis_name='device') in the "Define the loss function" section.

How can I fix this error? I've googled it, but have no idea what's going wrong here.


Solution

  • This happens because you are passing a scalar argument to a pmapped function. For example:

    import jax
    func = lambda x: x ** 2
    pfunc = jax.pmap(func)
    
    pfunc(1.0)
    # ValueError: pmap was requested to map its argument along axis 0, which implies
    # that its rank should be at least 1, but is only 0 (its shape is ())
    

    If you want to operate on a scalar, you should use the function without wrapping it in pmap:

    func(1.0)
    # 1.0
    

    Alternatively, if you want to use pmap, you should operate on an array whose leading dimension matches the number of devices:

    num_devices = len(jax.devices())
    x = jax.numpy.arange(num_devices)
    pfunc(x)
    # On a machine with 8 devices:
    # Array([ 0,  1,  4,  9, 16, 25, 36, 49], dtype=int32)
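
In a training loop like the one in the question, the rank-0 value may not be an obvious argument: it can be a scalar leaf inside the parameters or the optimizer state (for example a step counter). The following is a minimal sketch of that situation using plain JAX only. The `params`, `opt_state`, `step_fn` and `replicate` names, the toy loss and the argument order are all hypothetical stand-ins, not the tutorial's actual code; `replicate` here is a hand-written helper that plays the same role as the `flax.jax_utils.replicate` call mentioned in the question.

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()

def replicate(tree):
    # Hypothetical helper: stack one copy of every leaf per device along a
    # new leading axis, so that each leaf has rank >= 1 when pmap maps axis 0.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.stack([jnp.asarray(leaf)] * num_devices), tree
    )

# Hypothetical stand-ins for the model parameters and optimizer state.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
opt_state = {"count": jnp.zeros((), dtype=jnp.int32)}  # rank-0 leaf

# Diagnostic: collect the leaves that pmap would reject if passed unreplicated.
scalar_leaves = [
    leaf
    for leaf in jax.tree_util.tree_leaves((params, opt_state))
    if jnp.ndim(leaf) == 0
]

def step_fn(params, opt_state, x):
    # Toy loss, used only to make the sketch runnable.
    def loss_fn(p):
        return jnp.mean((x @ p["w"] + p["b"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Average the loss and gradients across devices.
    grads = jax.lax.pmean(grads, axis_name="device")
    loss = jax.lax.pmean(loss, axis_name="device")
    # Plain gradient-descent update with a fixed step size of 0.1.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    new_opt_state = {"count": opt_state["count"] + 1}
    return loss, new_params, new_opt_state

train_step_fn = jax.pmap(step_fn, axis_name="device")

# Input batch with a leading device axis: (devices, per-device batch, features).
x = jnp.ones((num_devices, 4, 3))

loss, params_r, opt_state_r = train_step_fn(
    replicate(params), replicate(opt_state), x
)
```

Every output keeps the leading device axis, so `params_r` and `opt_state_r` can be fed straight back into the next call to `train_step_fn` without replicating again. Whether this is the exact cause in the modified colab cannot be confirmed from the question alone; checking the arguments for rank-0 leaves as above is one way to find out.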