Tags: python, machine-learning, optimization, parallel-processing, jax

Parallelize with JAX over all GPU cores


To minimize the function x^2+y^2, I tried to implement the Adam optimizer from scratch with JAX:

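from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def fit(X, batches, params=[0.001, 0.9, 0.99, 1e-8]):
    global fun  # objective fun(X, batch), defined elsewhere

    # batches is an array containing n batches

    @jax.jit
    def adam_update(X, t, grads, m, v, alpha, b1, b2, epsilon):
        # Exponential moving averages of the gradient and its square
        m = b1*m + (1-b1)*grads
        v = b2*v + (1-b2)*grads**2
        # Bias-corrected moment estimates
        m_hat = m / (1-b1**t)
        v_hat = v / (1-b2**t)
        # m_hat and v_hat already carry the bias correction, so it is not applied again
        X = X - alpha*m_hat/((v_hat)**(1/2)+epsilon)
        return [X, m, v]

    dim=jnp.shape(X)[0]

    # Unpack the hyperparameters: learning rate, decay rates, epsilon
    params = jnp.array(params)
    alpha = params[0]
    b1=params[1]
    b2=params[2]
    epsilon=params[3]
    
    adam_update = jax.jit(partial(adam_update, alpha=alpha, b1=b1, b2=b2, epsilon=epsilon))

    m=jnp.zeros(dim)
    v=jnp.zeros(dim)

    # This Python loop runs inside the outer @jax.jit, so it is traced step by step
    for t, batch in enumerate(batches):
        fun_ = jax.jit(partial(fun, batch=batch))
        grads = jax.grad(fun_)(X)
        X, m, v = adam_update(X, t+1, grads, m, v)
    return X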

With JAX I could parallelize this simply with jax.pmap; however, it would only be parallelized over the 8 GPUs instead of over all GPU cores. Is there a way to parallelize this code over all cores?

Can it be that all cores of one GPU are miraculously used upon using @jax.jit? Also, why does it need 200 seconds for compiling for 1000 iterations, while the optax-Adam optimizer does not take so long to compile?


Solution

  • Can it be that all cores of one GPU are miraculously used upon using @jax.jit?

    In general, yes. For computations on a single device, the XLA GPU compiler will use all available cores of the GPU to complete a computation.

  • Also, why does it need 200 seconds for compiling for 1000 iterations, while the optax-Adam optimizer does not take so long to compile?

    This is because you are JIT-compiling a Python for-loop. Python loops within JIT are unrolled by JAX into a linear program (see JAX Sharp Bits: Control Flow), and compilation time grows with the size of the program.
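
The unrolling can be observed by counting the equations in the traced program. The following is a minimal sketch, not the code from the question: the functions `unrolled` and `looped` and the toy update `x - 0.1 * 2.0 * x` (a gradient step on x^2) are made up for illustration. It compares a Python loop with `jax.lax.fori_loop`, which traces the loop body only once.

```python
import jax
import jax.numpy as jnp


def unrolled(x, n_steps):
    # A Python loop is traced iteration by iteration, so every step
    # adds its own operations to the traced program.
    for _ in range(n_steps):
        x = x - 0.1 * 2.0 * x  # hypothetical gradient step on x**2
    return x


def looped(x, n_steps):
    # lax.fori_loop traces the body once, independent of n_steps.
    return jax.lax.fori_loop(0, n_steps, lambda i, x: x - 0.1 * 2.0 * x, x)


def count_eqns(f, n_steps, x):
    # Number of top-level equations in the traced program (jaxpr).
    return len(jax.make_jaxpr(lambda x: f(x, n_steps))(x).jaxpr.eqns)


x0 = jnp.ones(2)
n_eqns_10 = count_eqns(unrolled, 10, x0)
n_eqns_100 = count_eqns(unrolled, 100, x0)
n_loop_10 = count_eqns(looped, 10, x0)
n_loop_100 = count_eqns(looped, 100, x0)

print("unrolled:", n_eqns_10, n_eqns_100)
print("fori_loop:", n_loop_10, n_loop_100)
```

The unrolled program grows with the number of iterations, while the `fori_loop` version stays the same size, which is why a structured loop primitive is one option when the whole loop has to live inside a JIT-compiled function.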

    By contrast, the optax quick-start recommends JIT-compiling the step function, but does not JIT-compile the fitting loop. This would lead to much faster compilation times than the pattern used in your code, where the full for-loop is within a JIT-compiled function.
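
The "compile the step, not the loop" pattern can be sketched in plain JAX as follows. This is an illustration under assumptions rather than the optax implementation: the objective `fun`, the function names `step` and `fit`, the all-zero batches, and the starting point are hypothetical, and the hyperparameter defaults are copied from the question.

```python
import jax
import jax.numpy as jnp


def fun(X, batch):
    # Hypothetical objective: reduces to x^2 + y^2 when the batch is all zeros.
    return jnp.mean(jnp.sum((X - batch) ** 2, axis=-1))


@jax.jit
def step(X, m, v, t, batch, alpha=0.001, b1=0.9, b2=0.99, epsilon=1e-8):
    # One Adam update. Only this function is JIT-compiled, and it is
    # compiled once because the argument shapes and dtypes never change.
    grads = jax.grad(fun)(X, batch)
    m = b1 * m + (1 - b1) * grads
    v = b2 * v + (1 - b2) * grads**2
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    X = X - alpha * m_hat / (jnp.sqrt(v_hat) + epsilon)
    return X, m, v


def fit(X, batches):
    # Ordinary Python loop outside of jit: nothing is unrolled at trace time.
    m = jnp.zeros_like(X)
    v = jnp.zeros_like(X)
    for t, batch in enumerate(batches, start=1):
        X, m, v = step(X, m, v, jnp.asarray(t, dtype=X.dtype), batch)
    return X


X0 = jnp.array([1.0, -2.0])
batches = jnp.zeros((200, 4, 2))  # 200 hypothetical batches of 4 points each
X_fit = fit(X0, batches)
print(fun(X0, batches[0]), fun(X_fit, batches[0]))
```

The trade-off is that each iteration now pays a small Python dispatch cost, in exchange for compiling a single small program instead of one that contains every iteration.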