I would like to trace gradients through self.put_variable. Is there any way to make that possible? Or is there another way to update a param supplied to the module such that the update is traced?
import jax
from jax import numpy as jnp
from jax import grad, random, jit, vmap
import flax
from flax import linen as nn


class network(nn.Module):
    input_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        W = self.param('W', nn.initializers.normal(), (self.input_size, self.output_size))
        b = self.param('b', nn.initializers.normal(), (self.output_size,))
        # overwrite the 'b' entry of the "params" collection with x @ W + b
        self.put_variable("params", "b", (x @ W + b).reshape(5,))
        return jnp.sum(x + b)


if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x, key_param, key = random.split(key, 3)
    x = random.normal(key_x, (1, 5))
    module = network(5, 5)
    param = module.init(key_param, x)
    print(param)
    # x, param = module.apply(param, x, mutable=["params"])
    # print(param)
    print(grad(module.apply, has_aux=True)(param, x, mutable=["params"]))
My output grads are:
FrozenDict({
    params: {
        W: DeviceArray([[0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.]], dtype=float32),
        b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
    },
})
This shows that the gradients are not traced through the self.put_variable call: the grads for W are all zero, even though the updated b clearly depends on W.
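For comparison, here is a minimal sketch of the kind of tracing I am after (loss_fn is my own hypothetical wrapper, not a Flax API). Since apply with mutable=["params"] returns the mutated collection as a regular output, building the scalar loss from that collection should let gradients flow into W; my zero grads above presumably come from the fact that the differentiated output jnp.sum(x + b) only uses the old local b, and has_aux excludes the mutated collection from differentiation:

def loss_fn(params, x):
    # run the module and capture the mutated "params" collection,
    # which contains the value written by self.put_variable
    out, updated = module.apply(params, x, mutable=["params"])
    # build the scalar loss from the updated b, so the put_variable
    # value is part of the differentiated computation
    return jnp.sum(updated["params"]["b"])

print(grad(loss_fn)(param, x))

Is something like this the intended pattern, or is there a way to have the gradients traced through put_variable directly inside __call__?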