I have a joint cumulative distribution function (CDF) defined in Python as a function of a JAX array that returns a single value. Something like:
```python
import jax

def cumulative(inputs: jax.Array) -> float:
    ...
```
To get the gradient, I know I can just call `grad(cumulative)`, but that only gives me the first-order partial derivatives of `cumulative` with respect to the input variables.
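To make concrete what `grad` returns, here is a minimal sketch with a hypothetical toy CDF standing in for the real `cumulative`: n independent standard exponential variables, so F(x) = ∏ᵢ (1 − e^(−xᵢ)) for xᵢ > 0. The gradient is a vector with one first-order partial ∂F/∂xᵢ per input.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy joint CDF (assumption, not the real function):
# n independent standard exponentials, F(x) = prod_i (1 - exp(-x_i)), x_i > 0.
def cumulative(inputs: jax.Array) -> jax.Array:
    return jnp.prod(1.0 - jnp.exp(-inputs))

x = jnp.array([0.5, 1.0, 2.0])

# grad returns one first-order partial derivative dF/dx_i per input element.
first_order = jax.grad(cumulative)(x)
print(first_order.shape)  # (3,)
```

Each entry is only a first derivative; none of them is the n-th mixed partial.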
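Instead, what I would like to compute is the following, where F is my function and f is the joint probability density function:

$$f(x_1, \dots, x_n) = \frac{\partial^n F(x_1, \dots, x_n)}{\partial x_1 \, \partial x_2 \cdots \partial x_n}$$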
The order in which the partial derivatives are taken doesn't matter.
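For n = 2 the joint density is f(x₁, x₂) = ∂²F/∂x₁∂x₂, and this case can be checked with `jax.hessian`: its off-diagonal entries are the two mixed partials, which should agree with each other and with f. The sketch below assumes the same hypothetical toy CDF (independent standard exponentials), whose density ∏ᵢ e^(−xᵢ) is known in closed form; indices 0 and 1 in the code correspond to x₁ and x₂.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy CDF (assumption): independent standard exponentials.
def cumulative(inputs: jax.Array) -> jax.Array:
    return jnp.prod(1.0 - jnp.exp(-inputs))

# Closed-form joint PDF of the same toy distribution, used as a reference.
def density(inputs: jax.Array) -> jax.Array:
    return jnp.prod(jnp.exp(-inputs))

x = jnp.array([0.5, 1.5])
hess = jax.hessian(cumulative)(x)

# hess[0, 1] = d2F/dx0 dx1 and hess[1, 0] = d2F/dx1 dx0: same value,
# and for n = 2 both equal the joint density f(x).
print(hess[0, 1], hess[1, 0], density(x))
```

The full Hessian computes all n² second derivatives, so this only illustrates the n = 2 identity; it is not a scalable way to get the n-th mixed partial.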
So, I have several questions:
- How can I compute this efficiently in JAX? I assume I cannot just call `grad` n times.
- Once the resulting function is computed, will calling it cost more than calling the original function (does the cost grow by O(n), stay constant, or something else)?
- Alternatively, how can I compute a single partial derivative with respect to only one of the variables in the input array, rather than the entire array? (I would then repeat this n times, once per variable.)
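The per-variable approach from the last question could look roughly like the sketch below. It assumes the same hypothetical toy CDF, and `partial_derivative` and `mixed_partial` are hypothetical helper names. A single partial ∂F/∂xᵢ is computed with forward-mode `jax.jvp` using the i-th basis vector as the tangent, so only that one directional derivative is evaluated instead of the full gradient; nesting this once per variable gives the n-th mixed partial.

```python
import jax
import jax.numpy as jnp

def partial_derivative(fun, i):
    """Hypothetical helper: return x -> d fun / d x_i for scalar fun of a 1-D array."""
    def d_fun(x):
        # Tangent is the i-th basis vector, so the jvp is exactly d fun / d x_i.
        e_i = jnp.zeros_like(x).at[i].set(1.0)
        _, tangent_out = jax.jvp(fun, (x,), (e_i,))
        return tangent_out
    return d_fun

def mixed_partial(fun, n):
    """Hypothetical helper: d^n fun / (dx_0 ... dx_{n-1}) by nesting single partials."""
    for i in range(n):
        fun = partial_derivative(fun, i)
    return fun

# Hypothetical toy CDF (assumption): independent standard exponentials.
def cumulative(inputs):
    return jnp.prod(1.0 - jnp.exp(-inputs))

x = jnp.array([0.5, 1.0, 2.0])
density = mixed_partial(cumulative, x.shape[0])
print(density(x), jnp.prod(jnp.exp(-x)))  # the two values should match
```

This sketch only shows correctness on the toy case; it makes no claim about how the cost of the nested function scales with n.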