I used this tutorial from JAX to create an ensemble of networks. Currently I compute the loss of each network in a for-loop, which I would like to avoid:
for params in ensemble_params:
    loss = mse_loss(params, inputs=x, targets=y)

def mse_loss(params, inputs, targets):
    preds = batched_predict(params, inputs)
    loss = jnp.mean((targets - preds) ** 2)
    return loss
Here ensemble_params is a list of pytrees (lists of tuples holding JAX parameter arrays). The parameter structure of each network is the same.
I tried to get rid of the for-loop by applying jax.vmap:
ensemble_loss = jax.vmap(fun=mse_loss, in_axes=(0, None, None))
However, I keep getting the following error message, which I do not understand:
ValueError: vmap got inconsistent sizes for array axes to be mapped:
* most axes (8 of them) had size 3, e.g. axis 0 of argument params[0][0][0] of type float32[3,2];
* some axes (8 of them) had size 4, e.g. axis 0 of argument params[0][1][0] of type float32[4,3]
Here is a minimal reproducible example:
import jax
from jax import Array
from jax import random
import jax.numpy as jnp

def layer_params(dim_in: int, dim_out: int, key: Array) -> tuple[Array, Array]:
    w_key, b_key = random.split(key=key)
    weights = random.normal(key=w_key, shape=(dim_out, dim_in))
    biases = random.normal(key=w_key, shape=(dim_out,))
    return weights, biases

def init_params(layer_dims: list[int], key: Array) -> list[tuple[Array, Array]]:
    keys = random.split(key=key, num=len(layer_dims))
    params = []
    for dim_in, dim_out, key in zip(layer_dims[:-1], layer_dims[1:], keys):
        params.append(layer_params(dim_in=dim_in, dim_out=dim_out, key=key))
    return params

def init_ensemble(key: Array, num_models: int, layer_dims: list[int]) -> list:
    keys = random.split(key=key, num=num_models)
    models = [init_params(layer_dims=layer_dims, key=key) for key in keys]
    return models

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

batched_predict = jax.vmap(predict, in_axes=(None, 0))

def mse_loss(params, inputs, targets):
    preds = batched_predict(params, inputs)
    loss = jnp.mean((targets - preds) ** 2)
    return loss

if __name__ == "__main__":
    num_models = 4
    dim_in = 2
    dim_out = 4
    layer_dims = [dim_in, 3, dim_out]
    batch_size = 2

    key = random.PRNGKey(seed=1)
    key, subkey = random.split(key)
    ensemble_params = init_ensemble(key=subkey, num_models=num_models, layer_dims=layer_dims)

    key_x, key_y = random.split(key)
    x = random.normal(key=key_x, shape=(batch_size, dim_in))
    y = random.normal(key=key_y, shape=(batch_size, dim_out))

    for params in ensemble_params:
        loss = mse_loss(params, inputs=x, targets=y)
        print(f"{loss = }")

    ensemble_loss = jax.vmap(fun=mse_loss, in_axes=(0, None, None))
    losses = ensemble_loss(ensemble_params, x, y)
    print(f"{losses = }")  # Same losses expected as above.
The main issue here is that vmap maps over arrays, not over lists. You are passing a list of parameter objects and expecting vmap to map over the elements of that list. However, the semantics of vmap are that it maps over the leading axis of every array leaf in the argument pytree, and the leaves in your list differ in the size of their leading axis (3 for the first layer, 4 for the second), which is exactly what the error message reports.
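You can see this directly by printing the leaf shapes of the list you are passing (a small diagnostic sketch, reusing ensemble_params from the example above):

for leaf in jax.tree_util.tree_leaves(ensemble_params):
    print(leaf.shape)
# Prints (3, 2), (3,), (4, 3), (4,) once per model: the leading axes are the
# layer widths 3 and 4, not the ensemble size, so vmap sees inconsistent sizes.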
To fix this, instead of passing a list of parameter objects containing unbatched arrays, you need to pass a single parameter object containing batched arrays; in other words, you need a struct-of-arrays pattern rather than a list-of-structs pattern.
In your case, you can create your batched ensemble parameters this way:
ensemble_params = jax.tree_map(lambda *args: jnp.stack(args), *ensemble_params)
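After stacking, every leaf gains a leading ensemble axis of size num_models, which is exactly what in_axes=(0, None, None) expects. A quick check (a sketch, reusing the names from the example):

for leaf in jax.tree_util.tree_leaves(ensemble_params):
    print(leaf.shape)
# Now prints (4, 3, 2), (4, 3), (4, 4, 3), (4, 4): the leading axis is num_models.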
If you pass this to the ensemble_loss function, you get the expected output:
losses = Array([3.762451 , 4.39846 , 4.1425314, 6.045669 ], dtype=float32)
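As a sanity check, you can recover a single model by indexing the leading axis of every leaf in the stacked pytree and confirm that it reproduces the corresponding loss from the for-loop (a sketch, reusing the variables from the example):

first_model = jax.tree_util.tree_map(lambda leaf: leaf[0], ensemble_params)
print(jnp.allclose(mse_loss(first_model, inputs=x, targets=y), losses[0]))  # Expected: True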