Tags: python, jax

Parallelize inference of ensemble


I used this tutorial from JAX to create an ensemble of networks. Currently I compute the loss of each network in a for-loop, which I would like to avoid:

for params in ensemble_params:
    loss = mse_loss(params, inputs=x, targets=y)

def mse_loss(params, inputs, targets):
    preds = batched_predict(params, inputs)
    loss = jnp.mean((targets - preds) ** 2)
    return loss

Here ensemble_params is a list of pytrees (each one a list of tuples holding JAX parameter arrays). The parameter structure of each network is the same.
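
This can be confirmed with a quick sketch, since every member should flatten to the same treedef:

import jax

# All ensemble members share one pytree structure (treedef); only the leaf
# values differ. Treedefs are hashable, so the set should have one element.
treedefs = {jax.tree_util.tree_structure(p) for p in ensemble_params}
assert len(treedefs) == 1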

I tried to get rid of the for-loop by applying jax.vmap:

ensemble_loss = jax.vmap(fun=mse_loss, in_axes=(0, None, None))

However, I keep getting the following error message, which I do not understand.

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (8 of them) had size 3, e.g. axis 0 of argument params[0][0][0] of type float32[3,2];
  * some axes (8 of them) had size 4, e.g. axis 0 of argument params[0][1][0] of type float32[4,3]

Here is a minimal reproducible example:

import jax
from jax import Array
from jax import random
import jax.numpy as jnp

def layer_params(dim_in: int, dim_out: int, key: Array) -> tuple[Array, Array]:
    w_key, b_key = random.split(key=key)
    weights = random.normal(key=w_key, shape=(dim_out, dim_in))
    biases = random.normal(key=b_key, shape=(dim_out,))
    return weights, biases

def init_params(layer_dims: list[int], key: Array) -> list[tuple[Array, Array]]:
    keys = random.split(key=key, num=len(layer_dims))
    params = []
    for dim_in, dim_out, key in zip(layer_dims[:-1], layer_dims[1:], keys):
        params.append(layer_params(dim_in=dim_in, dim_out=dim_out, key=key))
    return params

def init_ensemble(key: Array, num_models: int, layer_dims: list[int]) -> list:
    keys = random.split(key=key, num=num_models)
    models = [init_params(layer_dims=layer_dims, key=key) for key in keys]
    return models

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

batched_predict = jax.vmap(predict, in_axes=(None, 0))

def mse_loss(params, inputs, targets):
    preds = batched_predict(params, inputs)
    loss = jnp.mean((targets - preds) ** 2)
    return loss

if __name__ == "__main__":

    num_models = 4
    dim_in = 2
    dim_out = 4
    layer_dims = [dim_in, 3, dim_out]
    batch_size = 2

    key = random.PRNGKey(seed=1)
    key, subkey = random.split(key)
    ensemble_params = init_ensemble(key=subkey, num_models=num_models, layer_dims=layer_dims)

    key_x, key_y = random.split(key)
    x = random.normal(key=key_x, shape=(batch_size, dim_in))
    y = random.normal(key=key_y, shape=(batch_size, dim_out))

    for params in ensemble_params:
        loss = mse_loss(params, inputs=x, targets=y)
        print(f"{loss = }")

    ensemble_loss = jax.vmap(fun=mse_loss, in_axes=(0, None, None))
    losses = ensemble_loss(ensemble_params, x, y)
    print(f"{losses = }")  # Same losses expected as above.

Solution

  • The main issue here is that vmap maps over arrays, not over lists.

    You are passing a list of parameter objects, expecting vmap to map over the elements of that list. However, with in_axes=(0, None, None), vmap maps over axis 0 of every array leaf in the first argument's pytree. Your leaves are the individual weight and bias arrays of each layer, so their leading axes are layer dimensions (the sizes 3 and 4 in the error message) rather than a shared ensemble axis, which is why vmap reports inconsistent sizes.

    To fix this, instead of passing a list of parameter objects containing unbatched arrays, you need to pass a single parameter object containing batched arrays; in other words, you need a struct-of-arrays pattern rather than a list-of-structs pattern.

    In your case, you can create your batched ensemble parameters this way:

    ensemble_params = jax.tree_util.tree_map(lambda *args: jnp.stack(args), *ensemble_params)
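
    The stacked leaves now carry a leading ensemble axis of size num_models, which is exactly the axis that in_axes=0 maps over. As a quick sanity check (a sketch; the shapes assume num_models = 4 and layer_dims = [2, 3, 4] from your example):

    # Each leaf gains a leading axis of size num_models.
    print(jax.tree_util.tree_map(lambda leaf: leaf.shape, ensemble_params))
    # Prints: [((4, 3, 2), (4, 3)), ((4, 4, 3), (4, 4))]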
    

    If you pass this to the ensemble_loss function, you get the expected output:

    losses = Array([3.762451 , 4.39846  , 4.1425314, 6.045669 ], dtype=float32)
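
    For completeness, here is a self-contained sketch of the full struct-of-arrays workflow, assuming mse_loss, x, y and the original (unstacked) ensemble_params from your example; the names stack_ensemble, stacked_params and ensemble_grads are illustrative, not part of your code:

    import jax
    import jax.numpy as jnp

    def stack_ensemble(params_list):
        # Turn a list of per-model pytrees into a single pytree whose
        # leaves carry a leading ensemble axis (struct-of-arrays).
        return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *params_list)

    stacked_params = stack_ensemble(ensemble_params)

    # Map the loss over the ensemble axis of the parameters only; inputs and
    # targets are shared across models, hence in_axes=(0, None, None).
    ensemble_loss = jax.jit(jax.vmap(mse_loss, in_axes=(0, None, None)))
    losses = ensemble_loss(stacked_params, x, y)  # shape (num_models,)

    # Per-model gradients follow the same pattern and also come back with a
    # leading ensemble axis on every leaf.
    ensemble_grads = jax.vmap(jax.grad(mse_loss), in_axes=(0, None, None))(stacked_params, x, y)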