This may be a very simple thing, but I was wondering how to perform the mapping in the following example.
Suppose we have a function whose derivative we want to evaluate with respect to `xt`, `yt` and `zt`, but which also takes the additional parameters `xs`, `ys` and `zs`:
```python
import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    # Euclidean distance between the target (xt, yt, zt) and the source (xs, ys, zs)
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)
```
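As a quick sanity check on what `grad(fn, argnums=(0, 1, 2))` returns, here is a minimal sketch on a single target point and a single source point (the scalar values are arbitrary and chosen only for illustration). It returns a tuple of three scalars, which should match the analytic derivative `(xt - xs) / r` and its `y` and `z` counterparts, where `r` is the distance:

```python
import jax.numpy as jnp
from jax import grad

def fn(xt, yt, zt, xs, ys, zs):
    # Euclidean distance between target and source
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

# Gradient with respect to the three target coordinates only
fn_grad = grad(fn, argnums=(0, 1, 2))

# One target point (2, 2, 2) and one source point (1, 3, 1), all scalar floats
dx, dy, dz = fn_grad(2., 2., 2., 1., 3., 1.)

# Analytic check: d/dxt of the distance is (xt - xs) / r, likewise for y and z
r = jnp.sqrt((2. - 1.) ** 2 + (2. - 3.) ** 2 + (2. - 1.) ** 2)
print(dx, dy, dz)
print((2. - 1.) / r, (2. - 3.) / r, (2. - 1.) / r)
```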
Now, let us define the input data:
```python
xt = jnp.array([1., 2., 3., 4.])
yt = jnp.array([1., 2., 3., 4.])
zt = jnp.array([1., 2., 3., 4.])
xs = jnp.array([1., 2., 3.])
ys = jnp.array([3., 3., 3.])
zs = jnp.array([1., 1., 1.])
```
To evaluate the gradient for every combination of values in `xt`, `yt` and `zt`, I currently have to do the following:
```python
# Gradient w.r.t. the target coordinates, vectorized over the source points
fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

a = []
for _xt in xt:
    for _yt in yt:
        for _zt in zt:
            a.append(fn_prime(_xt, _yt, _zt, xs, ys, zs))
```
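This results in a list of 64 tuples, one per `(xt, yt, zt)` combination, each holding three arrays of shape `(3,)` (one value per source point). Once the list is converted to a `jnp.array`, it has the following shape: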
```python
a = jnp.array(a)
print(f'shape = {a.shape}')
```

```
shape = (64, 3, 3)
```
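To make the layout of that `(64, 3, 3)` array explicit, here is a small sketch (the list comprehension is just a compact equivalent of the nested loops, and `a_grid` is a name introduced only for illustration). Axis 0 enumerates the `(xt, yt, zt)` combinations with `zt` varying fastest, axis 1 selects `d/dxt`, `d/dyt` or `d/dzt`, and axis 2 indexes the source points, so the first axis can be unfolded into a 4 x 4 x 4 grid:

```python
import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

xt = yt = zt = jnp.array([1., 2., 3., 4.])
xs = jnp.array([1., 2., 3.])
ys = jnp.array([3., 3., 3.])
zs = jnp.array([1., 1., 1.])

fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

# Same ordering as the nested loops: zt innermost, xt outermost
a = jnp.array([fn_prime(_xt, _yt, _zt, xs, ys, zs)
               for _xt in xt for _yt in yt for _zt in zt])

# Unfold axis 0 into (xt index, yt index, zt index)
a_grid = a.reshape(len(xt), len(yt), len(zt), 3, len(xs))

# d fn / d yt at (xt[1], yt[2], zt[3]) for source 0, read both ways
i, j, k = 1, 2, 3
print(a_grid[i, j, k, 1, 0], a[i * 16 + j * 4 + k, 1, 0])
```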
My question is: is there a way to avoid this `for` loop and evaluate all gradients in a single sweep?
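One possible approach, sketched here under the assumption that the loop's ordering must be preserved, is to build all target combinations up front with `jnp.meshgrid(..., indexing='ij')`, flatten them, and wrap `fn_prime` in a second, outer `vmap` that maps over the targets while sharing the sources. The names `XT`, `YT`, `ZT`, `all_grads` and `a_vec` are introduced only for this sketch:

```python
import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

xt = jnp.array([1., 2., 3., 4.])
yt = jnp.array([1., 2., 3., 4.])
zt = jnp.array([1., 2., 3., 4.])
xs = jnp.array([1., 2., 3.])
ys = jnp.array([3., 3., 3.])
zs = jnp.array([1., 1., 1.])

# Inner map: gradient w.r.t. the target coords, vectorized over source points
fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

# Every (xt, yt, zt) combination; indexing='ij' plus C-order ravel makes
# zt vary fastest, matching the nested loops
XT, YT, ZT = jnp.meshgrid(xt, yt, zt, indexing='ij')
XT, YT, ZT = XT.ravel(), YT.ravel(), ZT.ravel()  # each of shape (64,)

# Outer map over the 64 target combinations; sources are shared (None)
all_grads = vmap(fn_prime, in_axes=(0, 0, 0, None, None, None))(XT, YT, ZT, xs, ys, zs)

# all_grads is a tuple (d/dxt, d/dyt, d/dzt), each of shape (64, 3);
# stacking on axis 1 reproduces the (64, 3, 3) layout of jnp.array(a)
a_vec = jnp.stack(all_grads, axis=1)
print(a_vec.shape)  # (64, 3, 3)
```

Note that wherever a target coincides with a source (for example the target `(1, 3, 1)` and the first source), the distance is zero and the gradient of `jnp.sqrt` there should come out as `nan`, in both the looped and the vectorized versions.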