I'm quite new to flax and I was wondering what the correct way is to get this behavior:

param = f.init(key,x) 
new_param, y = f.apply(param,x) 

Here f is an nn.Module instance. f might go through multiple operations to compute new_param, and those operations might rely on the intermediate values of param to produce their output.

So basically: is there a way to access and update the parameters supplied to an nn.Module instance from within __call__, without losing the functional property, so that it can all be wrapped with the grad function transform?

hal9000
  • Found this discussion, where they show how to access parameters within a module: https://github.com/google/flax/discussions/1846. So all that is left is figuring out how to update them. – hal9000 Jun 17 '22 at 14:31

1 Answer

You can treat your parameter as a mutable variable. For reference, see how BatchNorm does it: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html

@nn.compact
def __call__(self, x):
    # 'mutable_params' is the variable collection name,
    # at the same "level" as 'params'
    some_params = self.variable('mutable_params', 'some_params', init_fn)
    # some_params is a Variable; read or assign the array through some_params.value

vars_init = model.init(key, x)
# vars_init = {'params': nested_dict_for_params, 'mutable_params': nested_dict_for_mutable_params}
y, mutated_vars = model.apply(vars_init, x, mutable=['mutable_params'])
# merge via unpacking, which works for both dict and FrozenDict (I'm not sure FrozenDict supports the | op)
vars_new = {**vars_init, **mutated_vars}
# equiv to vars_new = {'params': vars_init['params'], 'mutable_params': mutated_vars['mutable_params']}
YouJiacheng
  • return W@x+b TypeError: unsupported operand type(s) for @: 'Variable' and 'DeviceArray'. Variables seem to be treated differently from parameters; do you know if there is a way around it? – hal9000 Jun 21 '22 at 17:47
  • I actually tried your trick of using mutable=[] in the apply method, but with params, and then using self.put_variable("params","some_param_name",(x@W+b)) and has_aux=True in grad. This allows me to update the chosen parameter, but the problem is that the grads get stopped and aren't traced through the self.put_variable method, which is kind of sad. – hal9000 Jun 21 '22 at 18:17
  • https://github.com/google/flax/discussions/2215 (my current intermediate solution is here) – hal9000 Jun 21 '22 at 18:31
  • @hal9000 Variable.value is the array – YouJiacheng Jun 23 '22 at 08:40
  • https://stackoverflow.com/questions/72705707/is-there-a-way-to-trace-grads-through-self-put-variable-method-in-flax You can also use self.param and then self.put_variable to update it. – hal9000 Jun 23 '22 at 20:38