For example, say you set up a module that has params. If you want to regularize something in the loss, what is the pattern?
import haiku as hk
import jax
import jax.numpy as jnp

def loss(params, x, y):
    l = jnp.sum((y - mlp.apply(params, x)) ** 2)
    w = hk.get_params(params, 'w')  # does not work like this
    l += jnp.sum(w ** 2)  # L2 penalty on the weights
    return l
This pattern seems to be missing from the examples.
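`params` is essentially a read-only nested dictionary, keyed first by module name and then by parameter name, so you can get the value of a parameter by indexing into it like a dictionary:
print(params['linear']['w'])  # 'linear' is a hypothetical module name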
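If you want to update the parameters, you cannot do it in place; you first have to convert them to a mutable dictionary:
params_mutable = hk.data_structures.to_mutable_dict(params)
# 'linear' is a hypothetical module name; keep the array shape when overwriting.
params_mutable['linear']['w'] = jnp.full_like(params_mutable['linear']['w'], 3.14)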
params_new = hk.data_structures.to_immutable_dict(params_mutable)