
Error using JAX, Array slice indices must have static start/stop/step


I want to create a 2D Gaussian patch for each value in the darkField array. Ideally, the patch size would be computed as 2 * np.ceil(3 * sigma) + 1, where sigma is the corresponding value from the darkField array. In the example below I have fixed the size to 10 to avoid errors.

Once the Gaussian patch is normalized to sum to 1, I want to multiply it by the corresponding value from the intensityRefracted2DF array to obtain the blur, and then add this blur patch to the intensityRefracted3 array.

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from jax.scipy.signal import convolve2d
from functools import partial

@partial(jax.jit,static_argnums=(1,))
def gaussian_shape(sigma, size):
    """
    Generate a Gaussian shape.

    Args:
        sigma (float or 2D numpy array): Standard deviation(s) of the Gaussian shape.
        size (int): Size of the Gaussian shape.

    Returns:
        exponent (2D numpy array): Gaussian shape.
    """

    x = jnp.arange(0, size) - jnp.floor(size / 2)
    exponent = jnp.exp(-(x ** 2) / (2 * sigma ** 2))
    exponent = jnp.outer(exponent, exponent)
    exponent /= jnp.sum(exponent)
    return exponent

@jax.jit
def apply_dark_field(i, j, intensityRefracted2DF, intensityRefracted3, darkField):
    currDF_ij=darkField[i,j]
    patch = gaussian_shape(currDF_ij,10)
    size2 = patch.shape[0] // 2
    patch = patch * intensityRefracted2DF[i, j]


    intensityRefracted3 = intensityRefracted3.at[i - size2:i + size2 + 1, j - size2:j + size2 + 1].add(patch * intensityRefracted2DF[i, j])
    # intensityRefracted3 = jax.ops.index_add(intensityRefracted3, (i, j), intensityRefracted2DF[i, j] * (darkField[i, j] == 0))
    return intensityRefracted3

@jax.jit
def darkFieldLoop(intensityRefracted2DF, intensityRefracted3, darkField):
    currDF = jnp.zeros_like(intensityRefracted3)
    currDF = jnp.where(intensityRefracted2DF!=0,darkField,0)

    i = jnp.nonzero(currDF,size=currDF.shape[0])
    indices_i=i[0]
    indices_j=i[1]
    intensityRefracted3 = jnp.zeros_like(intensityRefracted3)

    intensityRefracted3 = jax.vmap(apply_dark_field, in_axes=(0, 0, None, None, None))(indices_i, indices_j, intensityRefracted2DF, intensityRefracted3, darkField)

    return intensityRefracted3

intensityRefracted2DF = np.random.rand(10,10)
intensityRefracted3 = np.zeros((10, 10))
darkField = np.random.rand(10, 10)

a=darkFieldLoop(intensityRefracted2DF,intensityRefracted3,darkField)

for i in range(a.shape[0]):
    plt.imshow(a[i])
    plt.show()

This is the error message:

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=3/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=3/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

I've also tried marking i and j as static via static_argnums, using partial:

@partial(jax.jit, static_argnums=(0,1))
def apply_dark_field(i, j, intensityRefracted2DF, intensityRefracted3, darkField):

which gives this error:

ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'jax._src.interpreters.batching.BatchTracer'> for function apply_dark_field is non-hashable.

Solution

  • The issue comes from the fact that JAX arrays cannot have a dynamic shape, and so dynamic start & end indices cannot be used in indexing expressions.

    Your solution of marking i and j as static would work, except that you are vmapping across these values, so by definition they cannot be static.

    The best solution here is probably to use lax.dynamic_slice and lax.dynamic_update_slice, which are operations designed exactly for the case that you have (where indices are dynamic, but slice sizes are static).

    You can replace this line:

    intensityRefracted3 = intensityRefracted3.at[i - size2:i + size2 + 1, j - size2:j + size2 + 1].add(patch * intensityRefracted2DF[i, j])
    

    with this:

    start_indices = (i - size2, j - size2)
    update = jax.lax.dynamic_slice(intensityRefracted3, start_indices, patch.shape)
    update += patch * intensityRefracted2DF[i, j]
    intensityRefracted3 = jax.lax.dynamic_update_slice(
        intensityRefracted3, update,  start_indices)
    

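    and it should work correctly with dynamic i and j. Be careful, though: if any of the specified indices are out of bounds, dynamic_slice and dynamic_update_slice do not raise an error. Instead they silently clamp the start indices so that the whole slice stays inside the array, which means a patch near an edge is written at a shifted position.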
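
One way to sidestep the clamping near the edges is to pad the array by half the patch size before slicing and crop the padding off afterwards. The following is a minimal sketch of that idea, not a drop-in replacement for the code above: the names `add_patch`, `PATCH` and `HALF`, the 5x5 canvas, and the all-ones patch are assumptions chosen for illustration, and the patch size is assumed to be odd and static.

```python
import jax
import jax.numpy as jnp
from jax import lax

PATCH = 3          # hypothetical static patch size (odd), for illustration only
HALF = PATCH // 2


@jax.jit
def add_patch(canvas, i, j, patch):
    """Add `patch` centred on (i, j); i and j may be traced values."""
    # Pad so that a patch centred anywhere inside `canvas` fits without clamping.
    padded = jnp.pad(canvas, HALF)
    # In padded coordinates, the top-left corner of the patch is exactly (i, j).
    window = lax.dynamic_slice(padded, (i, j), (PATCH, PATCH))
    padded = lax.dynamic_update_slice(padded, window + patch, (i, j))
    # Crop the padding off again; anything that fell outside is discarded.
    return padded[HALF:-HALF, HALF:-HALF]


canvas = jnp.zeros((5, 5))
patch = jnp.ones((PATCH, PATCH))

# Without padding, an out-of-range start index is clamped rather than rejected:
# a length-3 slice starting at index 4 of a length-5 array starts at 2 instead.
clamped = lax.dynamic_slice(jnp.arange(5), (4,), (3,))

corner = add_patch(canvas, 0, 0, patch)  # only the in-bounds 2x2 overlap is kept
centre = add_patch(canvas, 2, 2, patch)  # the full 3x3 patch lands inside

# vmap over the centres yields one canvas per centre; sum them to accumulate.
ii = jnp.array([0, 2])
jj = jnp.array([0, 2])
stacked = jax.vmap(add_patch, in_axes=(None, 0, 0, None))(canvas, ii, jj, patch)
total = stacked.sum(axis=0)
```

Note that, as in the original darkFieldLoop, vmap returns a separate output array for each (i, j) pair, so the per-pixel results still have to be summed along the leading axis to obtain a single blurred image.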