I have an ensemble of models and want to assign the same parameters to each of them. The models' parameters and the new parameters share the same underlying pytree structure. Currently I use the following approach, which builds a fresh copy for every model in a loop (a list comprehension):
```python
import jax
import jax.numpy as jnp

model1 = [
    [jnp.asarray([1]), jnp.asarray([2, 3])],
    [jnp.asarray([4]), jnp.asarray([5, 6])],
]
model2 = [
    [jnp.asarray([2]), jnp.asarray([3, 4])],
    [jnp.asarray([5]), jnp.asarray([6, 7])],
]
models = [model1, model2]

params = [
    [jnp.asarray([3]), jnp.asarray([4, 5])],
    [jnp.asarray([6]), jnp.asarray([7, 8])],
]

# jax.tree_map is deprecated/removed in recent JAX; use jax.tree_util.tree_map.
models = [jax.tree_util.tree_map(jnp.copy, params) for _ in range(len(models))]
```
Is there a more efficient way in JAX to assign the parameters from `params` to each model in `models`?
Since JAX arrays are immutable, there is no need to copy the parameter arrays. Every model can simply reference the same `params` pytree, which gives the same result without any copying:
```python
models = len(models) * [params]
```
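A minimal sketch of why sharing is safe, assuming the same `params` as in the question (`num_models` is a hypothetical name used only for illustration). Every entry of `models` is the identical `params` object, and any "update" done the JAX way, for example with `jax.tree_util.tree_map`, builds new arrays and a new list structure instead of modifying the shared ones:

```python
import jax
import jax.numpy as jnp

params = [
    [jnp.asarray([3]), jnp.asarray([4, 5])],
    [jnp.asarray([6]), jnp.asarray([7, 8])],
]
num_models = 2  # hypothetical ensemble size

# Every entry is the very same `params` object; no array data is copied.
models = num_models * [params]
assert all(m is params for m in models)

# Updating one model is functional: tree_map returns a new pytree with new
# arrays, so the other models (and `params` itself) are left untouched.
models[0] = jax.tree_util.tree_map(lambda x: x + 1, models[0])

print(models[0][0][0])  # [4]
print(models[1][0][0])  # [3]
print(params[0][0])     # [3]
```

One caveat: the nested Python lists are shared as well, so in-place list mutation such as `models[1][0][0] = ...` would be visible through every model that still references `params`. Replacing a model's parameters with a new pytree, as above, avoids this.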