I want to pmap a function over 2D arrays that might (or might not) contain NaN values. That function must then apply some operations to the finite values that exist in each row (toy examples at the end of the post).
I know how many points (per row) contain NaNs, even before I jax.jit anything. Thus, I should be able to do:
import jax
import jax.numpy as jnp

inds = jnp.where(jnp.isfinite(line), size=Finite_points_number)
but I am not able to pass this size into the pmap-ed function.
I have tried to:
i) pmap over the list with the number of good points per row:
data_array = jnp.array([
    [1, 2, 3, 4],
    [4, 5, 6, jnp.nan],
])
sizes = jnp.asarray((4, 3))  # Number of valid points per row

def jitt_function(line, N):
    """
    Over-simplified function to showcase the problem
    """
    inds = jnp.where(jnp.isfinite(line), size=N)
    return jnp.sum(line[inds])

pmap_func = jax.pmap(jitt_function, in_axes=(0, 0))
pmap_func(data_array, sizes)
and it fails with
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. The error occurred while tracing the function jitt_function at [...] for pmap. This concrete value was not available in Python because it depends on the value of the argument 'N'.
ii) I have also tried to turn the number of points (N) into a static argument:
jitt_function = jax.jit(jitt_function, static_argnames=("N",))
pmap_func = jax.pmap(jitt_function, in_axes=(0, 0))
which also fails:
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function jitt_function is non-hashable.
Even if I managed to transform this into a static argument, I would still need to "know" the line number, so that I could access the correct number of good points.
Question: Is there any way for me to do this within jax?
If you have a different number of points per pmap batch, then the number is not static. The result of your intended operation would be a ragged array (i.e. a 2D array whose rows have differing numbers of elements) and ragged arrays are not supported in JAX.
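When the per-row operation is a reduction such as the sum in the toy example, one way around the ragged shape is to keep every row at full length and zero out the non-finite entries instead of selecting the finite ones. The following is only a minimal sketch under that assumption, not a general replacement for index selection. `masked_sum` is a hypothetical name, and `jax.vmap` is used so that the snippet runs on a single device (`jax.pmap` needs at least as many devices as there are rows).

```python
import jax
import jax.numpy as jnp

data_array = jnp.array([
    [1.0, 2.0, 3.0, 4.0],
    [4.0, 5.0, 6.0, jnp.nan],
])

def masked_sum(line):
    # Hypothetical helper: the three-argument jnp.where keeps the shape of
    # `line`, so no static size is needed. Finite entries pass through and
    # every non-finite entry is replaced by 0 before summing.
    return jnp.sum(jnp.where(jnp.isfinite(line), line, 0.0))

# Map over rows on one device; the per-row function itself is unchanged
# if it is later handed to jax.pmap on a multi-device setup.
row_sums = jax.vmap(masked_sum)(data_array)
print(row_sums)
```

Because the output shape no longer depends on how many values are finite, the number of valid points per row does not have to be passed in at all.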
If you actually have a static number of elements, meaning an equal number in every batch, then you can use the size argument of jnp.where to do this computation. It might look something like this:
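from functools import partial

def jitt_function(line, N):
    """
    Over-simplified function to showcase the problem
    """
    # fill_value pads the returned indices (not the gathered values)
    # when fewer than N entries are finite
    inds = jnp.where(jnp.isfinite(line), size=N, fill_value=0)
    return jnp.sum(line[inds])

# N is bound as a plain Python int, so it is static when pmap traces the function
pmap_func = jax.pmap(partial(jitt_function, N=4))
pmap_func(data_array)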
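If you have fewer than the specified number of entries in each batch, then one option is to specify the fill_value argument to jnp.where to pad the output. Note that fill_value pads the returned indices rather than the gathered values: with fill_value=0, each padded slot points at element 0, so line[0] is added to the sum once more per missing entry. For the sum to come out as expected, the padded slots must contribute zero, for example by masking with the three-argument form jnp.where(jnp.isfinite(line), line, 0) before summing.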
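One detail worth checking is what the padded slots actually refer to. The sketch below uses the second row of the toy data and compares two paddings: index 0, and an out-of-bounds index gathered with `mode="fill"`. The variable names are illustrative only, and the out-of-bounds variant is just one possible way to make the padding neutral for a sum.

```python
import jax.numpy as jnp

line = jnp.array([4.0, 5.0, 6.0, jnp.nan])
finite = jnp.isfinite(line)

# Padding the indices with 0 makes the padded slot point at line[0] again.
(idx_zero,) = jnp.where(finite, size=4, fill_value=0)
sum_zero_fill = jnp.sum(line[idx_zero])

# Padding with an out-of-bounds index and gathering with mode="fill"
# makes the padded slot contribute 0.0 instead.
(idx_oob,) = jnp.where(finite, size=4, fill_value=line.shape[0])
sum_oob_fill = jnp.sum(line.at[idx_oob].get(mode="fill", fill_value=0.0))

print(idx_zero, sum_zero_fill)  # indices [0 1 2 0], sum 19.0
print(idx_oob, sum_oob_fill)    # indices [0 1 2 4], sum 15.0
```

With the zero padding, the first element is counted twice; with the out-of-bounds padding, only the three finite values are summed.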