I'm quite new to Flax and I was wondering what the correct way is to get this behavior:
param = f.init(key,x)
new_param, y = f.apply(param,x)
Here f is an nn.Module instance.
f might go through multiple operations to produce new_param, and those operations might rely on the intermediate values of param to produce their output.
So basically, is there a way to access and update the parameters supplied to an nn.Module instance from within its __call__ method, without losing the functional property, so that it can all be wrapped with the grad function transform?