Tags: python, jax, flax

Can you update parameters of a module from inside the nn.compact of that module? (self modifying networks)


I'm quite new to Flax and I was wondering what the correct way is to get this behavior:

    param = f.init(key, x)
    new_param, y = f.apply(param, x)

Here f is an nn.Module instance. f might go through multiple operations to get new_param, and those operations might rely on the intermediate param to produce their output.

So basically, is there a way to access and update the parameters supplied to an nn.Module instance from within its __call__, without losing the functional property, so that it can all be wrapped with the grad function transform?


Solution

  • You can treat your parameter as a mutable variable. For reference, see how BatchNorm keeps its running statistics: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html

    @nn.compact
    def __call__(self, x):
        # 'mutable_params' is the variable collection name,
        # at the same "level" as 'params'.
        some_params = self.variable('mutable_params', 'some_params', init_fn)
        # Read the current value with some_params.value; assign to it to update.

    vars_init = model.init(key, x)
    # vars_init = {'params': nested_dict_for_params, 'mutable_params': nested_dict_for_mutable_params}
    y, mutated_vars = model.apply(vars_init, x, mutable=['mutable_params'])
    # Merge by unpacking, which works for both plain dicts and FrozenDicts.
    vars_new = {**vars_init, **mutated_vars}
    # equiv to vars_new = {'params': vars_init['params'], 'mutable_params': mutated_vars['mutable_params']}