
Efficient copying of an ensemble in JAX


I have an ensemble of models and want to assign the same parameters to each model. The models' parameters and the new parameters share the same underlying structure. Currently I do this with a for-loop:

import jax
import jax.numpy as jnp

model1 = [
    [jnp.asarray([1]), jnp.asarray([2, 3])],
    [jnp.asarray([4]), jnp.asarray([5, 6])],
]

model2 = [
    [jnp.asarray([2]), jnp.asarray([3, 4])],
    [jnp.asarray([5]), jnp.asarray([6, 7])],
]

models = [model1, model2]

params = [
    [jnp.asarray([3]), jnp.asarray([4, 5])],
    [jnp.asarray([6]), jnp.asarray([7, 8])],
]

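# jax.tree_map is deprecated in newer JAX releases; jax.tree_util.tree_map is the equivalent call.
models = [jax.tree_util.tree_map(jnp.copy, params) for _ in range(len(models))]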

Is there a more efficient way in JAX to assign the parameters from params to each model in models?


Solution

  • Since JAX arrays are immutable, there is no need to copy the parameter arrays. You can get the same result by repeating a reference to params:

    models = len(models) * [params]
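
One caveat with the list-multiplication approach: every entry of the result is the same nested Python list, so reassigning an element inside one model would show up in all of them. The sketch below is one possible way around that, assuming you want independent containers per model while still sharing the immutable arrays. The names `n_models`, `shared` and `independent` are illustrative and not part of the original question.

```python
import jax
import jax.numpy as jnp

params = [
    [jnp.asarray([3]), jnp.asarray([4, 5])],
    [jnp.asarray([6]), jnp.asarray([7, 8])],
]
n_models = 2  # illustrative stand-in for len(models)

# List multiplication: every entry is the very same Python object as `params`.
shared = n_models * [params]

# Identity tree_map rebuilds the list containers for each model,
# while the leaves stay the same array objects (no array data is copied).
independent = [
    jax.tree_util.tree_map(lambda x: x, params) for _ in range(n_models)
]

# "Updating" an array produces a new array, so only model 0 changes here.
independent[0][0][0] = independent[0][0][0].at[0].set(100)
```

If the models are only ever read, or are replaced wholesale rather than modified element by element, the plain `len(models) * [params]` form is sufficient.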