Tags: python, gradient-descent, jax

Gradient Accumulation with JAX


I made a simple script to try out gradient accumulation with JAX. The idea is to have a large batch size (e.g. 64) that is split into small chunks (e.g. 4) that fit in the GPU's memory. For each chunk, the resulting gradient, stored in a pytree, is added to the current batch gradient. The update is performed only once all chunks of the large batch have been computed. In this particular example, we simply try to fit random 512-dimensional vectors to random booleans with a linear layer. Here is the script:

import jax
import jax.numpy as jnp
from jax import jit, random
from jax.experimental import optimizers
from functools import partial
from jax.nn.initializers import normal, zeros
from typing import Callable
from dataclasses import dataclass

@dataclass
class Jax_model:
    init_fun: Callable
    apply_fun: Callable


def Dense(input_size: int, output_size: int, init_kernel=normal(), init_bias=zeros):

    def init_fun(key):
        key, sub_key1, sub_key2 = jax.random.split(key, 3)
        params = {
            'I': init_kernel(sub_key1, (input_size, output_size) ),
            'I_b': init_bias(sub_key2, (1,output_size) ),
        }
        return params

    def apply_fun(params, inputs):
        I, I_b, = params['I'], params['I_b']
        logits = inputs @ I + I_b
        return logits

    return Jax_model(init_fun, apply_fun)


def divide_pytree(pytree, div):
    for pt in jax.tree_util.tree_leaves(pytree):
        pt = pt / div
    return pytree


def add_pytrees(pytree1, pytree2):
    for pt1, pt2 in zip( jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2) ):
        pt1 = pt1 + pt2
    return pytree1


rng_key = random.PRNGKey(42)
batch_size = 64
accumulation_size = 4
model_dim = 512
n_iter = 50

model = Dense(model_dim, 1)
rng_key, sub_key = random.split(rng_key)
init_params = model.init_fun(sub_key)
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)

@jit
def update(i, current_opt_state, current_batch):
    N = current_batch[0].shape[0]
    K = accumulation_size
    num_gradients = N//K
    accumulation_batch = (current_batch[ib][0:K] for ib in range(len(current_batch)))
    value, grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
    value = value / num_gradients
    grads = divide_pytree(grads, num_gradients)
    for k in range(K,N,K):
        accumulation_batch = (current_batch[ib][k:k+K] for ib in range(len(current_batch)))
        new_value, new_grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
        value = value + (new_value / num_gradients)
        grads = add_pytrees(grads, divide_pytree(new_grads, num_gradients))
    return opt_update(i, grads, current_opt_state), value

def loss_func(current_params, current_batch):
    inputs, labels = current_batch
    predictions = model.apply_fun(current_params, inputs)
    loss = jnp.square(labels-predictions).sum()
    return loss

for i in range(n_iter):
    rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)
    inputs = jax.random.uniform(sub_key1, (batch_size, model_dim))
    labels = jax.random.uniform(sub_key2, (batch_size, 1)) > 0.5
    batch = inputs, labels
    opt_state, batch_loss = update(i, opt_state, batch)
    print(i, batch_loss)

I have doubts about divide_pytree and add_pytrees. Do they actually modify the current batch gradient, or am I missing something? Moreover, do you see any speed issue with this code? In particular, should I use jax.lax.fori_loop in place of the traditional Python for loop?


Solution

  • Regarding the pytree computations: as written, your functions return the input unmodified, because a statement like pt = pt / div only rebinds the local loop variable and never modifies the leaves of the pytree. The better approach is to use jax.tree_util.tree_map; for example:

    from jax.tree_util import tree_map
    
    def divide_pytree(pytree, div):
      return tree_map(lambda pt: pt / div, pytree)
    
    def add_pytrees(pytree1, pytree2):
      return tree_map(lambda pt1, pt2: pt1 + pt2, pytree1, pytree2)
    

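    As a quick check (this snippet is my own illustration, not part of the original answer), both tree_map versions above return a new pytree with transformed leaves while leaving their inputs untouched:

    import jax.numpy as jnp

    grads = {'I': jnp.ones((2, 3)), 'I_b': jnp.ones((1, 3))}
    averaged = divide_pytree(grads, 16)       # every leaf divided by 16
    summed = add_pytrees(averaged, averaged)  # leaf-wise sum of the two pytrees
    print(summed['I'][0, 0])                  # 0.125; grads itself is unchanged
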
  • Regarding performance: anything in the for loop will be unrolled when JIT-compiled, with one repeated copy of all XLA instructions per iteration of the loop. If you have 5 iterations, that's not really an issue. If you have 5000, it would significantly slow down compilation, because XLA needs to analyze and optimize 5000 explicit copies of the instructions in the loop.
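
    You can see this unrolling directly (a small illustration of my own, not from the original answer) by printing the jaxpr of a function that contains a Python loop:

    import jax
    import jax.numpy as jnp

    def summed(x):
        total = jnp.zeros_like(x)
        # A Python loop: each iteration is traced separately.
        for _ in range(5):
            total = total + jnp.sin(x)
        return total

    # The printed jaxpr contains five separate sin/add pairs,
    # one per iteration of the Python loop.
    print(jax.make_jaxpr(summed)(jnp.ones(3)))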

    jax.lax.fori_loop can help with compilation time, but it does not generally lead to optimal code, particularly when running on CPU or GPU.
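
    For completeness, here is a rough sketch of my own (not part of the original answer) of how the chunk loop could be written with jax.lax.fori_loop so the loop body is compiled only once; it assumes the loss_func, accumulation_size, and tree_map helpers defined above:

    import jax
    import jax.numpy as jnp
    from jax.tree_util import tree_map

    def accumulate_gradients(params, current_batch):
        inputs, labels = current_batch
        num_chunks = inputs.shape[0] // accumulation_size

        def body(chunk_idx, carry):
            value_acc, grads_acc = carry
            start = chunk_idx * accumulation_size
            # Slice out one chunk of the large batch with a traced start index.
            chunk = (
                jax.lax.dynamic_slice_in_dim(inputs, start, accumulation_size, axis=0),
                jax.lax.dynamic_slice_in_dim(labels, start, accumulation_size, axis=0),
            )
            value, grads = jax.value_and_grad(loss_func)(params, chunk)
            # Accumulate the running mean of the per-chunk values and gradients.
            value_acc = value_acc + value / num_chunks
            grads_acc = tree_map(lambda g_acc, g: g_acc + g / num_chunks, grads_acc, grads)
            return value_acc, grads_acc

        zero_grads = tree_map(jnp.zeros_like, params)
        return jax.lax.fori_loop(0, num_chunks, body, (jnp.zeros(()), zero_grads))

    The returned (value, grads) pair can then be passed to opt_update exactly as in the original update function.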

    Better would be to use broadcasted or vmapped operations where possible to express the logic of the loops without explicit looping.
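
    As an illustration of that last point (again a sketch of my own, reusing loss_func from the question), the per-chunk gradients can be computed in a single vmapped call by reshaping the large batch into (num_chunks, chunk_size, ...) and averaging the resulting gradient pytree with tree_map:

    import jax
    import jax.numpy as jnp
    from jax.tree_util import tree_map

    def chunked_value_and_grad(params, current_batch, chunk_size):
        inputs, labels = current_batch
        num_chunks = inputs.shape[0] // chunk_size
        # Reshape the large batch into explicit (num_chunks, chunk_size, ...) chunks.
        chunked_inputs = inputs.reshape(num_chunks, chunk_size, *inputs.shape[1:])
        chunked_labels = labels.reshape(num_chunks, chunk_size, *labels.shape[1:])
        # One (value, grads) pair per chunk, computed without a Python loop.
        values, grads = jax.vmap(
            lambda x, y: jax.value_and_grad(loss_func)(params, (x, y))
        )(chunked_inputs, chunked_labels)
        # Average over the chunk axis to match the division by num_gradients above.
        return values.mean(), tree_map(lambda g: g.mean(axis=0), grads)

    Note that this computes all chunks at once, so it trades peak memory for compilation and runtime speed; if the whole point of the accumulation is to keep memory low, the fori_loop variant above is the closer match.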