I'm quite new to Flax and was wondering what the correct way is to get this behavior:
param = f.init(key, x)
new_param, y = f.apply(param, x)
Here f is an nn.Module instance.
f might go through multiple operations to produce new_param, and those operations might rely on the intermediate value of param to produce their output.
So basically, is there a way to access and update the parameters supplied to an nn.Module instance from within __call__ without losing the functional property, so that it can all be wrapped with the grad function transform?
You can treat your parameter as a mutable variable. For reference, see how BatchNorm does it: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html
@nn.compact
def __call__(self, x):
    some_params = self.variable('mutable_params', 'some_params', init_fn)
    # 'mutable_params' is the variable collection name,
    # at the same "level" as 'params'
vars_init = model.init(key, x)
# vars_init = {'params': nested_dict_for_params, 'mutable_params': nested_dict_for_mutable_params}
y, mutated_vars = model.apply(vars_init, x, mutable=['mutable_params'])
vars_new = vars_init | mutated_vars  # I'm not sure whether FrozenDict supports the | operator
# equivalent to vars_new = {'params': vars_init['params'], 'mutable_params': mutated_vars['mutable_params']}