jax, haiku

How do I get a parameter from a params pytree in Haiku (JAX)?


For example, you set up a module that has params. If you want to regularize one of those parameters in the loss, what is the pattern?

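import haiku as hk
import jax
import jax.numpy as jnp

# `mlp` is assumed to be hk.without_apply_rng(hk.transform(...)) defined elsewhere.
def loss(params, x, y):
    l = jnp.sum((y - mlp.apply(params, x)) ** 2)
    w = hk.get_params(params, 'w')  # does not work like this
    l += jnp.sum(w ** 2)
    return l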

The Haiku examples don't seem to cover this pattern.


Solution

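  • params is essentially a read-only, two-level dictionary: the outer key is the module name (for example 'mlp/~/linear_0') and the inner key is the parameter name (for example 'w' or 'b'). You read a parameter by indexing it like a nested dictionary:

    print(params['mlp/~/linear_0']['w'])  # the module name depends on your network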
    

    You cannot update params in place. Convert it to a mutable dictionary first, change the value, and convert it back:

    params_mutable = hk.data_structures.to_mutable_dict(params)
    # 'linear' is a placeholder for whichever module owns the 'w' you want to change.
    params_mutable['linear']['w'] = jnp.full_like(params_mutable['linear']['w'], 3.14)
    params_new = hk.data_structures.to_immutable_dict(params_mutable)