
I have a joint cumulative distribution function defined in Python as a function that takes a JAX array and returns a single value. Something like:

def cumulative(inputs: array) -> float:
    ...

To get the gradient, I know I can just do grad(cumulative), but that only gives me the first-order partial derivatives of cumulative with respect to the input variables. Instead, assuming F is my function and f is the joint probability density function, what I would like to compute is this:

$$f(x_1, \dots, x_n) = \frac{\partial^n F(x_1, \dots, x_n)}{\partial x_1 \, \partial x_2 \cdots \partial x_n}$$

The order of the partial differentiation doesn't matter.
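
For example, with n = 2 this is the single mixed second-order partial derivative, i.e. an off-diagonal entry of the Hessian of F:

$$f(x_1, x_2) = \frac{\partial^2 F(x_1, x_2)}{\partial x_1 \, \partial x_2}$$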

So, I have several questions:

  • how can I compute this efficiently in JAX? I assume I cannot just call grad n times
  • once the resulting function is built, will it have a higher call complexity than the original function (does the cost grow by a factor of O(n), stay constant, or something else)?
  • alternatively, how can I compute a single partial derivative with respect to only one of the variables of the input array, as opposed to the entire array? (I would then just repeat this n times, once per variable.)

1 Answer


JAX generally treats gradients as being with respect to individual arguments, not elements within arguments. Within this context, one built-in function that is similar to what you want (but not exactly the same) is jax.hessian, which computes the Hessian matrix of second derivatives; for example:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.prod(x ** 2)

x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
#  [72. 18. 24.]
#  [48. 24.  8.]]
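
Regarding the last bullet in your question: a single first-order partial derivative with respect to one element can be read off the full gradient by indexing, or (a sketch using a hypothetical helper partial_wrt, not anything built into JAX) by wrapping the function so that only that element is treated as the differentiated argument:

def partial_wrt(f, x, i):
  # Differentiate a scalar wrapper in which only element i varies.
  return jax.grad(lambda xi: f(x.at[i].set(xi)))(x[i])

print(jax.grad(f)(x)[1])     # df/dx_1 read off the full gradient
# 36.0
print(partial_wrt(f, x, 1))  # the same partial via the wrapper
# 36.0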

For higher-order derivatives with respect to individual elements of the array, I think you'll have to manually nest the gradients. You could do so with a helper function that looks something like this:

def grad_all(f):
  def gradfun(x):
    # Unpack the array into one scalar argument per element.
    args = tuple(x)
    f_args = lambda *args: f(jnp.array(args))
    # Differentiate once with respect to each argument in turn,
    # nesting jax.grad to build the mixed n-th order partial.
    for i in range(len(args)):
      f_args = jax.grad(f_args, argnums=i)
    return f_args(*args)
  return gradfun

print(grad_all(f)(x))
# 48.0
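
To connect this back to your CDF use case, here is a toy sanity check (my own example, assuming a factorized CDF built from jax.scipy.stats.norm, not taken from your code): the mixed partial of prod(Phi(x_i)) should recover the joint density prod(phi(x_i)).

from jax.scipy.stats import norm

# Toy CDF of independent standard normals: F(x) = prod_i Phi(x_i),
# so the mixed partial w.r.t. every element is the density prod_i phi(x_i).
def toy_cdf(inputs):
  return jnp.prod(norm.cdf(inputs))

pts = jnp.array([0.1, -0.5, 1.2])
print(grad_all(toy_cdf)(pts))   # mixed partial via nested grads
print(jnp.prod(norm.pdf(pts)))  # analytic joint density; the two should match
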
  • Your function seems sound and runs, but I'm getting negative values, which I believe shouldn't be possible. Could it be that your function is correct but that JAX somehow produces negative values where it shouldn't? – Ben Nov 16 '21 at 18:48
  • Depending on the inputs, floating point rounding errors could cause JAX to produce negative values when none would be expected. – jakevdp Nov 16 '21 at 19:02
  • I see, is there a simple way to prevent this? Can I just clip the value of the gradient at each step of the loop, for instance? – Ben Nov 16 '21 at 19:07
  • Maybe? Hard to say without more information. But taking gradients of clipped functions can run into other issues with boundary conditions. – jakevdp Nov 16 '21 at 19:10