
Let f: R -> R be an infinitely differentiable function. What is the computational complexity of calculating the first n derivatives of f in JAX? A naive application of the chain rule suggests that each multiplication doubles the work, so the nth derivative would require at least 2^n times as many operations. I imagine, though, that clever manipulation of formal series would reduce the number of required calculations and eliminate duplication, especially if the derivatives are JIT-compiled with JAX. Is there a difference between the JAX, TensorFlow and Torch implementations?

https://openreview.net/forum?id=SkxEF3FNPH discusses this topic, but doesn't state the computational complexity.
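To make the naive approach concrete, here is a minimal sketch that builds the nth derivative by composing `jax.grad` n times and counts the equations in the traced program with `jax.make_jaxpr`. The test function `f` and the helpers `nested_grad` and `traced_size` are hypothetical. The equation count of the un-optimized jaxpr is only a rough proxy for cost: XLA may simplify or deduplicate work once the function is passed through `jax.jit`.

```python
import jax
import jax.numpy as jnp


def nested_grad(f, n):
    """Naive n-th derivative of a scalar function: apply jax.grad n times."""
    for _ in range(n):
        f = jax.grad(f)
    return f


def f(x):
    # Hypothetical test function: a product of two elementary functions.
    return jnp.sin(x) * jnp.exp(x)


def traced_size(g, x):
    """Number of top-level equations in the jaxpr traced for g(x), before XLA optimization."""
    return len(jax.make_jaxpr(g)(x).jaxpr.eqns)


sizes = [traced_size(nested_grad(f, n), 1.0) for n in range(6)]
for n, size in enumerate(sizes):
    print(f"derivative order {n}: {size} jaxpr equations")
```

Printing the sizes for a few orders shows how quickly the traced program grows for a given function; the growth rate depends on which operations `f` contains.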

Dan Leonte

2 Answers


What is the computational complexity of calculating the first n derivatives of f in Jax?

There's not much you can say in general about the computational complexity of Nth derivatives. For example, for a function like jnp.sin, the Nth derivative is O(1): it just cycles through sin, cos, -sin and -cos as N grows. For a polynomial of degree k, the Nth derivative is identically zero for N > k. Other functions may have complexity that is linear, polynomial or even exponential in N, depending on the operations they contain.
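The two examples above can be checked numerically with plain nested `jax.grad`. This minimal sketch (the helper `nth_derivative` and the polynomial `cubic` are hypothetical) only verifies the derivative values; it does not show whether JAX exploits the cheap structure of these derivatives when computing them.

```python
import jax
import jax.numpy as jnp


def nth_derivative(f, n):
    """Naive n-th derivative of a scalar function: compose jax.grad n times."""
    for _ in range(n):
        f = jax.grad(f)
    return f


x = 0.5

# jnp.sin: derivatives cycle through sin, cos, -sin, -cos.
sin_derivs = [nth_derivative(jnp.sin, k)(x) for k in range(6)]


def cubic(x):
    # Degree-3 polynomial: every derivative of order > 3 is identically zero.
    return x**3 - 2.0 * x


cubic_derivs = [nth_derivative(cubic, k)(x) for k in range(6)]
print(sin_derivs)
print(cubic_derivs)
```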

I imagine, though, that clever manipulation of formal series would reduce the number of required calculations and eliminate duplication, especially if the derivatives are JIT-compiled with JAX

You imagine correctly! One implementation of this idea is the jax.experimental.jet module, which is an experimental transform designed for computing higher-order derivatives efficiently and accurately. It doesn't cover all JAX functions, but it may be complete enough to do what you have in mind.
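A minimal sketch of how `jax.experimental.jet` can return all of the first n derivatives in one Taylor-mode pass. It pushes the input path x(t) = x + t through `f`, whose first derivative is 1 and whose higher derivatives are 0. Following the `jet` docstring, the series terms are derivatives rather than Taylor coefficients divided by k!, so the output terms are f'(x), ..., f^(n)(x). The helper names and the test function `g` are hypothetical; because `jet` is experimental and covers only some primitives, the sketch compares its output with nested `jax.grad`.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet


def first_n_derivatives(f, x, n):
    """Return [f(x), f'(x), ..., f^(n)(x)] for a scalar function f via Taylor-mode jet (n >= 1)."""
    x = jnp.asarray(x, dtype=jnp.float32)
    # Input path x(t) = x + t: first derivative 1, all higher derivatives 0.
    terms = [jnp.ones_like(x)] + [jnp.zeros_like(x)] * (n - 1)
    y0, ys = jet(f, (x,), (terms,))
    return jnp.stack([y0, *ys])


def nested_grad(f, n):
    """Naive reference: apply jax.grad n times."""
    for _ in range(n):
        f = jax.grad(f)
    return f


def g(x):
    # Hypothetical test function built from sin and exp.
    return jnp.exp(jnp.sin(x))


x0 = 0.5
taylor = first_n_derivatives(g, x0, 4)
naive = jnp.stack([nested_grad(g, k)(x0) for k in range(5)])
print(taylor)
print(naive)
```

The jet version propagates a fixed-length series through each primitive once, instead of re-tracing the whole derivative program for every order.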

jakevdp

If L is the complexity of evaluating the scalar function f, then L*(n+1)^2 is an upper bound for the complexity of finding the first n derivatives of f as coefficients of a truncated Taylor series.
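A minimal sketch of the arithmetic behind this bound, assuming each quantity is stored as an array of n+1 normalized Taylor coefficients c_k = f^(k)(x)/k!. Multiplying two such truncated series is a Cauchy product: coefficient k takes k+1 multiplications, so the whole product takes 1 + 2 + ... + (n+1) = (n+1)(n+2)/2 of them, and a function evaluated with L such operations costs on the order of L·(n+1)^2. The helpers `series_mul` and `coeffs_to_derivatives` are hypothetical, not part of JAX.

```python
import math

import jax.numpy as jnp


def series_mul(a, b):
    """Truncated product of two Taylor coefficient arrays of equal length n+1.

    c_k = sum_{j=0}^{k} a_j * b_{k-j}; coefficient k uses k+1 multiplications,
    so the whole product uses (n+1)(n+2)/2 of them.
    """
    n1 = a.shape[0]
    return jnp.stack([jnp.dot(a[: k + 1], jnp.flip(b[: k + 1])) for k in range(n1)])


def coeffs_to_derivatives(c):
    """Convert normalized coefficients c_k = f^(k)(x)/k! back to derivatives f^(k)(x)."""
    factorials = jnp.array([math.factorial(k) for k in range(c.shape[0])], dtype=c.dtype)
    return c * factorials


# Usage: derivatives of f(x) = x**3 at x = 2, computed as (x * x) * x on truncated series.
n = 4
x = jnp.array([2.0, 1.0] + [0.0] * (n - 1))  # the input x(t) = 2 + t, truncated after t**n
cube = series_mul(series_mul(x, x), x)
print(coeffs_to_derivatives(cube))  # values: 8, 12, 12, 6, 0
```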

The general idea is that each elementary function can be implemented for truncated Taylor series in the equivalent of one or two truncated series multiplications.
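As a sketch of how an elementary function reduces to about one or two series multiplications, the recurrences below come from differentiating y = exp(u), which gives y' = u'y, and s = sin(u), c = cos(u), which give s' = u'c and c' = -u's, then matching coefficients. They again assume normalized coefficients c_k = f^(k)(x)/k!. Each new coefficient is a convolution-like sum over earlier ones, so exp costs about one truncated multiplication and the sin/cos pair about two. The function names are hypothetical, not JAX API.

```python
import math

import jax
import jax.numpy as jnp


def series_exp(u):
    """Normalized Taylor coefficients of exp(u) from those of u (length n+1)."""
    y = [jnp.exp(u[0])]
    for k in range(1, u.shape[0]):
        # From y' = u' * y:  k * y_k = sum_{j=1}^{k} j * u_j * y_{k-j}
        y.append(sum(j * u[j] * y[k - j] for j in range(1, k + 1)) / k)
    return jnp.stack(y)


def series_sin_cos(u):
    """Normalized Taylor coefficients of sin(u) and cos(u), computed together."""
    s = [jnp.sin(u[0])]
    c = [jnp.cos(u[0])]
    for k in range(1, u.shape[0]):
        # From s' = u' * c and c' = -u' * s, matched coefficient by coefficient.
        s.append(sum(j * u[j] * c[k - j] for j in range(1, k + 1)) / k)
        c.append(-sum(j * u[j] * s[k - j] for j in range(1, k + 1)) / k)
    return jnp.stack(s), jnp.stack(c)


def nested_grad(f, k):
    """Naive reference: apply jax.grad k times."""
    for _ in range(k):
        f = jax.grad(f)
    return f


def g(x):
    return jnp.exp(jnp.sin(x))


# Usage: first 4 derivatives of exp(sin(x)) at x = 0.5, compared with nested jax.grad.
n = 4
u = jnp.array([0.5, 1.0] + [0.0] * (n - 1))  # the input x(t) = 0.5 + t
sin_u, _ = series_sin_cos(u)
coeffs = series_exp(sin_u)
factorials = jnp.array([math.factorial(k) for k in range(n + 1)], dtype=coeffs.dtype)
taylor = coeffs * factorials
naive = jnp.stack([nested_grad(g, k)(0.5) for k in range(n + 1)])
print(taylor)
print(naive)
```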

Lutz Lehmann