Does JAX support taking the derivative w.r.t. an index of a vector-valued variable? Consider this example (where a is a vector/array):
def test_func(a):
    return a[0]**a[1]
I can pass the argument number to grad(...) via argnums, but I cannot seem to pass the index of a vector-valued argument as in the example above. I tried passing a tuple of tuples, i.e.,
grad(test_func, argnums=((0,),))
but that does not work.