
Does JAX support taking the derivative with respect to a single index of a vector-valued variable? Consider this example (where a is a vector/array):

def test_func(a):
  return a[0]**a[1]

I can pass an argument number to grad(...) via argnums, but I cannot find a way to pass the index of an element within a vector-valued argument, as in the example above. I tried passing a tuple of tuples, i.e.,

grad(test_func, argnums=((0,),))

but that does not work.

user654123



There's no built-in transform that can take gradients with respect to certain elements of arrays, but you can straightforwardly do this via a wrapper function that splits the array into individual elements; for example:

import jax
import jax.numpy as jnp

def test_func(a):
  return a[0]**a[1]

a = jnp.array([1.0, 2.0])
fgrad = jax.grad(lambda *args: test_func(jnp.array(args)), argnums=0)
print(fgrad(*a))
# 2.0

If you want to take a gradient with respect to all the inputs individually (returning a vector of gradients with respect to each entry), you can use jax.jacobian:

print(jax.jacobian(test_func)(a))
# [2. 0.]
jakevdp
Very nice, this solves my problem directly. I hadn't thought about wrapping it in a lambda that takes a list and calls the function with an array. As you point out, using `jax.jacobian` is great when you need the differential w.r.t. all parameters (which I might need in the future). – user654123 Nov 23 '21 at 10:06