Tags: python, graph, deep-learning, pytorch, jax

Jax Implementation of function similar to Torch's 'Scatter'


For graph learning purposes, I am trying to implement a global sum pooling function that takes as inputs batched graph representations 'x' of size (n x d) and a corresponding vector of batch indices of size (n x 1). I then want to compute the sum over all representations belonging to each batch. Here is a graphical representation: torch's scatter function
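
For example, here is a small illustration of the behaviour I am after (made-up numbers, with 'batch' holding the graph index of each row of 'x'):

import jax.numpy as jnp

x = jnp.array([[1., 2.],
               [3., 4.],
               [5., 6.]])       # (n=3, d=2) representations
batch = jnp.array([0, 0, 1])    # rows 0 and 1 belong to graph 0, row 2 to graph 1

# Desired output, one summed representation per graph:
# [[4., 6.],    <- sum of rows with batch == 0
#  [5., 6.]]    <- sum of rows with batch == 1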

This is my current attempt:

import jax.numpy as jnp

def global_sum_pool(x, batch):
    graph_reps = []
    i = 0
    n = jnp.max(batch)  # maximum batch index (a traced value under jit)
    while True:
        # mask out the rows belonging to batch i and sum them
        ind = jnp.where(batch == i, True, False).reshape(-1, 1)
        ind = jnp.tile(ind, x.shape[1])
        x_ind = jnp.where(ind == True, x, 0.0)
        graph_reps.append(jnp.sum(x_ind, axis=0))
        if i == n:
            break
        i += 1
    return jnp.array(graph_reps)

I get the following exception on the line if i == n:

jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function make_step at /venvs/jax_env/lib/python3.11/site-packages/equinox/_jit.py:37 for jit. 

I understand this is because JAX does not know the maximum value of the 'batch' array a priori at compile time, and hence cannot allocate memory. Does anyone know a workaround or a different implementation?


Solution

  • Rather than implementing this with a Python-level loop, you should use JAX's built-in scatter operations. The most convenient interface for these is the Array.at syntax. If I understand your goal correctly, it might look something like this:

    import jax.numpy as jnp
    import numpy as np
    
    # Generate some data
    num_batches = 4
    n = 10
    d = 3
    x = np.random.randn(n, d)
    ind = np.random.randint(low=0, high=num_batches, size=(n,))
    
    # Compute the per-batch sums with a scatter-add via the .at syntax
    result = jnp.zeros((num_batches, d)).at[ind].add(x)
    print(result.shape)
    # (4, 3)
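
If you prefer a named helper, jax.ops.segment_sum should give the same result (a minimal sketch, reusing the x, ind, and num_batches defined above; note that num_segments must be a static Python integer for this to work under jit):

    import jax

    # segment_sum performs the same per-index sum as the scatter-add above;
    # num_segments must be known statically, which is what makes it jit-friendly.
    result2 = jax.ops.segment_sum(x, ind, num_segments=num_batches)
    print(jnp.allclose(result, result2))
    # True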