Let's suppose I have a function that returns the sum of its inputs.
import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2
Now I would like to loop over different values of r1 and r2, save each result, and add it to a counter. This is what I mean:
a = 0
r1 = jnp.arange(0,3)
r2 = jnp.arange(0,3)
s = 0
for i in range(len(r1)):
    for j in range(len(r2)):
        s += some_func(a, r1[i], r2[j])
print(s)
DeviceArray(18, dtype=int32)
My question is, how do I do this with jax.vmap to avoid writing the for loops? I have something like this so far:
vmap(some_func, in_axes=(None, 0,0), out_axes=0)(jnp.arange(0,3), jnp.arange(0,3))
but this gives me the following error:
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (None, 0, 0) for value tree PyTreeDef((*, *)).
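One possible reading of this error, assuming the call was meant to mirror the loop above: in_axes lists three entries, but the call passes only two positional arguments (a is missing), so the specification cannot line up with the arguments. The sketch below passes a as well. It runs, but mapping r1 and r2 along the same axis only pairs r1[i] with r2[i], so it does not cover every combination. The names r and paired are hypothetical.

```python
import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2

a = 0
r = jnp.arange(0, 3)

# Three arguments now match the three entries of in_axes:
# a is broadcast (None), r1 and r2 are both mapped along axis 0.
paired = vmap(some_func, in_axes=(None, 0, 0), out_axes=0)(a, r, r)

# paired[i] == some_func(a, r[i], r[i]): only the "diagonal" pairs,
# so paired.sum() differs from the double loop's total.
print(paired)
```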
I have a feeling that the error is in in_axes, but I am not sure how to get vmap to pick a value for r1, loop over r2, and then do the same for all r1 whilst saving intermediate results.
Any help is appreciated.
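A minimal sketch of the behaviour described above (fix one r1, sweep over r2, repeat for every r1), assuming that nesting two vmap calls is acceptable. The names over_r2, over_r1_r2 and grid are hypothetical and not part of the original code.

```python
import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2

a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# Inner vmap: hold a and r1 fixed, map over the elements of r2.
over_r2 = vmap(some_func, in_axes=(None, None, 0))
# Outer vmap: hold a and the whole r2 array fixed, map over r1.
over_r1_r2 = vmap(over_r2, in_axes=(None, 0, None))

# grid[i, j] == some_func(a, r1[i], r2[j]); these are the intermediate results.
grid = over_r1_r2(a, r1, r2)
s = grid.sum()
print(s)  # same total as the double loop, 18
```

For a function this simple, plain broadcasting such as some_func(a, r1[:, None], r2[None, :]).sum() should give the same total. The nested vmap form is mainly useful when some_func is written for scalar inputs and does not broadcast on its own.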