In order to minimize the function x^2+y^2, I tried to implement the Adam optimizer from scratch with JAX:
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def fit(X, batches, params=[0.001, 0.9, 0.99, 1e-8]):
    global fun  # fun(X, batch) is the loss function, defined at module level
    # batches is an array containing n batches, stacked along the leading axis

    @jax.jit
    def adam_update(X, t, grads, m, v, alpha, b1, b2, epsilon):
        # exponential moving averages of the gradient and its square
        m = b1 * m + (1 - b1) * grads
        v = b2 * v + (1 - b2) * grads ** 2
        # bias-corrected moment estimates
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        X = X - alpha * m_hat / (jnp.sqrt(v_hat) + epsilon)
        return [X, m, v]

    dim = jnp.shape(X)[0]
    params = jnp.array(params)
    alpha = params[0]
    b1 = params[1]
    b2 = params[2]
    epsilon = params[3]
    adam_update = jax.jit(partial(adam_update, alpha=alpha, b1=b1, b2=b2, epsilon=epsilon))
    m = jnp.zeros(dim)
    v = jnp.zeros(dim)
    for t, batch in enumerate(batches):
        fun_ = jax.jit(partial(fun, batch=batch))
        grads = jax.grad(fun_)(X)
        X, m, v = adam_update(X, t + 1, grads, m, v)
    return X
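For reference, fun is a global loss function with signature fun(X, batch); for this toy problem it just ignores the batch. A minimal way to call fit (the shapes and values here are only placeholders):

def fun(X, batch):                 # global objective: x^2 + y^2
    return jnp.sum(X ** 2)

X0 = jnp.array([1.0, 2.0])
batches = jnp.zeros((1000, 1))     # 1000 dummy batches
X_min = fit(X0, batches)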
With JAX I could parallelize this easily with jax.pmap; however, that only parallelizes over the 8 GPUs, not over all the cores of each GPU. Is there a way to parallelize this code over all cores?
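What I mean is roughly the following (a sketch; the shapes and device count are just placeholders for my setup):

n_dev = jax.local_device_count()          # 8 GPUs in my case
X0s = jnp.zeros((n_dev, 2))               # one starting point per device
batched = jnp.zeros((n_dev, 1000, 1))     # one stream of dummy batches per device
X_out = jax.pmap(fit)(X0s, batched)       # one replica of fit per GPU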
Can it be that all cores of one GPU are miraculously used upon using @jax.jit? Also, why does it need 200 seconds to compile for 1000 iterations, while the optax Adam optimizer does not take that long to compile?
Can it be that all cores of one GPU are miraculously used upon using @jax.jit?
In general, yes. For computations on a single device, the XLA GPU compiler will use all available cores of the GPU to complete a computation.
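For example, nothing in the following code asks for parallelism explicitly, yet the compiled kernels it produces are launched across many GPU cores (a minimal sketch; the sizes are arbitrary):

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # one gradient-descent step on sum(x**2); XLA decides how to spread the work
    return x - 0.001 * jax.grad(lambda x: jnp.sum(x ** 2))(x)

x = jnp.ones((1_000_000,))
x = step(x)   # runs on a single device, but uses many of its cores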
Also, why does it need 200 seconds to compile for 1000 iterations, while the optax Adam optimizer does not take that long to compile?
This is because you are JIT-compiling a Python for-loop. Python loops inside JIT are unrolled by JAX into a single linear program (see JAX Sharp Bits: Control Flow), so with 1000 iterations the Adam update is traced and inlined 1000 times, and compilation time grows with the size of that program.
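If you do want the whole training loop inside a single jit, the usual fix is to express the loop with a structured control-flow primitive such as jax.lax.scan, which compiles the loop body only once. A minimal sketch of your fit along those lines (the names, hyperparameters, and dummy objective are just illustrative):

import jax
import jax.numpy as jnp
from functools import partial

def fit_scan(X, batches, fun, alpha=0.001, b1=0.9, b2=0.99, epsilon=1e-8):
    # batches has shape (n_batches, ...); scan compiles the body once
    # instead of unrolling it n_batches times.
    def step(carry, batch):
        X, m, v, t = carry
        grads = jax.grad(fun)(X, batch)
        m = b1 * m + (1 - b1) * grads
        v = b2 * v + (1 - b2) * grads ** 2
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        X = X - alpha * m_hat / (jnp.sqrt(v_hat) + epsilon)
        return (X, m, v, t + 1), None

    init = (X, jnp.zeros_like(X), jnp.zeros_like(X), 1.0)
    (X, _, _, _), _ = jax.lax.scan(step, init, batches)
    return X

def fun(X, batch):                        # dummy objective: x^2 + y^2
    return jnp.sum(X ** 2)

X_final = jax.jit(partial(fit_scan, fun=fun))(jnp.array([1.0, 2.0]), jnp.zeros((1000, 1)))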
By contrast, the optax quick-start recommends JIT-compiling the step function, but does not JIT-compile the fitting loop itself. That leads to much faster compilation than the pattern used in your code, where the full for-loop sits inside a JIT-compiled function.
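Concretely, that pattern looks roughly like this (a sketch based on the optax quick-start; the objective and batches are placeholders):

import jax
import jax.numpy as jnp
import optax

def fun(X, batch):                        # dummy objective: x^2 + y^2
    return jnp.sum(X ** 2)

optimizer = optax.adam(learning_rate=0.001, b1=0.9, b2=0.99, eps=1e-8)

@jax.jit                                  # only the per-step update is compiled
def step(X, opt_state, batch):
    grads = jax.grad(fun)(X, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    X = optax.apply_updates(X, updates)
    return X, opt_state

X = jnp.array([1.0, 2.0])
opt_state = optimizer.init(X)
for batch in jnp.zeros((1000, 1)):        # plain Python loop, outside of jit
    X, opt_state = step(X, opt_state, batch)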