
I would like to trace the gradients through self.put_variable. Is there any way to make that possible? Or is there another way to update the params supplied to the module such that the update is traced?

import jax 
from jax import numpy as jnp 
from jax import grad,random,jit,vmap 
import flax 
from flax import linen as nn 


class network(nn.Module):
    input_size : int 
    output_size : int 
    @nn.compact
    def __call__(self,x):
        W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
        b = self.param('b',nn.initializers.normal(),(self.output_size,))

      
        self.put_variable("params","b",(x@W+b).reshape(5,))  
    
        return jnp.sum(x+b)


if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x,key_param,key = random.split(key,3)
    x = random.normal(key_x,(1,5))

    module = network(5,5)
    param = module.init(key_param,x)
    print(param)
    #x,param = module.apply(param,x,mutable=["params"])
    #print(param)
    print(grad(module.apply,has_aux=True)(param,x,mutable=["params"]))

My output grads are:

FrozenDict({
    params: {
        W: DeviceArray([[0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.]], dtype=float32),
        b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
    },

This shows that the gradients are not traced through the self.put_variable call: the grads for W are all zero, even though b clearly relies upon W.

hal9000

2 Answers


The output of your model is jnp.sum(x + b), which has no dependence on W, which in turn implies that the gradient with respect to W should be zero. With this in mind, the output you show above looks to be correct.
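To see why, here is a minimal plain-JAX sketch (my own illustration, not your module): writing the derived value into some container does not rebind the local name b, so the returned quantity never touches W.

import jax
from jax import numpy as jnp

def f(W, b, x):
    # analogous to put_variable: the new value is stored but never read again
    stored = {"b": (x @ W + b).reshape(5,)}
    return jnp.sum(x + b)  # still uses the original b

W = jnp.ones((5, 5))
b = jnp.zeros(5)
x = jnp.ones((1, 5))
print(jax.grad(f)(W, b, x))  # all zeros: the output does not depend on W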

Edit: It sounds like you're expecting the result of x@W+b that you used in your variable to be reflected in the value of b used in the return statement; perhaps you want something like this?

    def __call__(self,x):
        W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
        b = self.param('b',nn.initializers.normal(),(self.output_size,))

        b = x@W+b
        self.put_variable("params","b",b.reshape(5,)) 
    
        return jnp.sum(x+b)
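For completeness, once put_variable has run, the updated value can be read back by calling apply with mutable (a short sketch, assuming the module and param from your question):

y, mutated = module.apply(param, x, mutable=["params"])
new_b = mutated["params"]["b"]  # now reflects x @ W + b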

That said, it's unclear to me from the question what your ultimate goal is, and given that you're asking about such an uncommon construct, I suspect this may be an XY problem. Perhaps you can edit your question to say more about what you're trying to accomplish.

jakevdp
  • self.put_variable("params","b",(x@W+b).reshape(5,)) updates b and relies upon W and b, so jnp.sum(x + b) should rely upon W, x and b, since b is now b(x, W, b_0). – hal9000 Jun 21 '22 at 20:02
  • If the parameters are in a plain pytree I can do the same thing and the grads are tracked, but I don't want to do it in plain JAX because I want to use the clean parameter management of Flax. – hal9000 Jun 21 '22 at 20:04
  • I see what you mean: b is still bound to the previous b value, instead of the new one. – hal9000 Jun 22 '22 at 09:44
  • My ultimate goal is to make self-modifying networks. I'm currently trying to recreate the results from "A Modern Self-Referential Weight Matrix That Learns to Modify Itself". For that I need to be able to do new_param, y = f(param, x), which I struggled with as I'm quite new to Flax. The new_param must depend on the old param and x so I can trace grads back to the old param. – hal9000 Jun 22 '22 at 18:11
  • I solved the problem in the end (see my other answer), thanks for your help either way. :) – hal9000 Jun 22 '22 at 18:12

As @jakevdp noted, the test above is incorrect, because b is still tied to the previous b value.
According to https://github.com/google/flax/discussions/2215, self.put_variable is traced.

Testing if that is actually the case using the code below:

import jax 
from jax import numpy as jnp 
from jax import grad,random,jit,vmap 
import flax 
from flax import linen as nn 

class network(nn.Module):
    input_size : int 
    output_size : int 
    @nn.compact
    def __call__(self,x):
        W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
        b = self.param('b',nn.initializers.normal(),(self.output_size,))

        b = x@W+b  # rebind b, otherwise the return below is still tied to the previous value
        self.put_variable("params","b",b.reshape(5,))
     
        return jnp.sum(x+b)

def test_update(param,x):
    # apply the module with mutable=["params"] so the value written by
    # put_variable is returned as part of the updated params
    _, param = module.apply(param,x,mutable=["params"])
    return jnp.sum(param["params"]["b"]+x),param

if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x,key_param,key = random.split(key,3)
    x = random.normal(key_x,(1,5))

    module = network(5,5)
    param = module.init(key_param,x)
    print(param)

    print(grad(test_update,has_aux=True)(param,x))

Output:

FrozenDict({
    params: {
        W: DeviceArray([[ 0.01678762,  0.00234134,  0.00906202,  0.00027337,
                       0.00599653],
                     [-0.00729604, -0.00417799,  0.00172333, -0.00566238,
                       0.0097266 ],
                     [ 0.00378883, -0.00901531,  0.01898266, -0.01733185,
                      -0.00616944],
                     [-0.00806503,  0.00409351,  0.0179838 , -0.00238476,
                       0.00252594],
                     [ 0.00398197,  0.00030245, -0.00640218, -0.00145424,
                       0.00956188]], dtype=float32),
        b: DeviceArray([-0.00905032, -0.00574646,  0.01621638, -0.01165553,
                     -0.0285466 ], dtype=float32),
    },
})
(FrozenDict({
    params: {
        W: DeviceArray([[-1.1489547 , -1.1489547 , -1.1489547 , -1.1489547 ,
                      -1.1489547 ],
                     [-2.0069852 , -2.0069852 , -2.0069852 , -2.0069852 ,
                      -2.0069852 ],
                     [ 0.98777294,  0.98777294,  0.98777294,  0.98777294,
                       0.98777294],
                     [ 0.9311977 ,  0.9311977 ,  0.9311977 ,  0.9311977 ,
                       0.9311977 ],
                     [-0.2883922 , -0.2883922 , -0.2883922 , -0.2883922 ,
                      -0.2883922 ]], dtype=float32),
        b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
    },
}), FrozenDict({
    params: {
        W: DeviceArray([[ 0.01678762,  0.00234134,  0.00906202,  0.00027337,
                       0.00599653],
                     [-0.00729604, -0.00417799,  0.00172333, -0.00566238,
                       0.0097266 ],
                     [ 0.00378883, -0.00901531,  0.01898266, -0.01733185,
                      -0.00616944],
                     [-0.00806503,  0.00409351,  0.0179838 , -0.00238476,
                       0.00252594],
                     [ 0.00398197,  0.00030245, -0.00640218, -0.00145424,
                       0.00956188]], dtype=float32),
        b: DeviceArray([-0.01861148, -0.00523183,  0.03968921, -0.01952654,
                     -0.06145691], dtype=float32),
    },
}))

The first FrozenDict holds the original parameters.
The second FrozenDict holds the grads; they are non-zero for W, so they are clearly traced through self.put_variable.
The last FrozenDict holds the returned parameters, where we can see that b has been correctly updated.
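As a quick sanity check of these numbers (my own addition, assuming the code above): the loss is jnp.sum(x@W + b + x), so each row of the W gradient should equal the corresponding entry of x, and the b gradient should be all ones.

grads, new_param = grad(test_update, has_aux=True)(param, x)
assert jnp.allclose(grads["params"]["W"], jnp.broadcast_to(x.T, (5, 5)))
assert jnp.allclose(grads["params"]["b"], jnp.ones(5))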

hal9000
  • I wouldn't write this as an answer. Instead, edit your question and mark this as an update/fix... – droebi Jun 23 '22 at 15:22