Let $f: \mathbb{R} \to \mathbb{R}$ be an infinitely differentiable function. What is the computational complexity of calculating the first $n$ derivatives of $f$ in JAX? A naive application of the chain rule suggests that each multiplication doubles the number of terms at every differentiation, so the $n$th derivative would require at least $2^n$ times as many operations.
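To see what the nested approach actually produces, here is a minimal sketch, assuming a recent JAX install. The test function $f(x) = \sin(x)\,e^x$ and the helper `nth_derivative` are hypothetical choices for illustration. It builds the $n$th derivative by nesting `jax.grad` $n$ times and prints the number of primitive equations in the traced jaxpr for each $n$. `jax.make_jaxpr` shows the program before XLA compilation, so any optimizations XLA applies under `jax.jit` are not reflected in these counts.

```python
import jax
import jax.numpy as jnp


def nth_derivative(f, n):
    """Hypothetical helper: n-th derivative of a scalar function by nesting jax.grad n times."""
    for _ in range(n):
        f = jax.grad(f)
    return f


def f(x):
    # Example function chosen for illustration (an assumption): f(x) = sin(x) * exp(x)
    return jnp.sin(x) * jnp.exp(x)


x = 0.5
for n in range(1, 6):
    df = nth_derivative(f, n)
    # Size of the traced program (before XLA optimization): number of jaxpr equations
    n_eqns = len(jax.make_jaxpr(df)(x).jaxpr.eqns)
    print(f"n={n}  f^(n)(x)={float(df(x)):.6f}  jaxpr equations={n_eqns}")
```

For this particular $f$ the exact answer is $f^{(n)}(x) = 2^{n/2} e^x \sin(x + n\pi/4)$, which gives a quick correctness check independent of how the program size grows.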
I imagine, though, that clever manipulation of formal power series would reduce the number of required calculations and eliminate duplicated work, especially if the derivatives are JIT-compiled with JAX. Is there a difference between the JAX, TensorFlow and Torch implementations?
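JAX ships an experimental module, `jax.experimental.jet`, that implements the formal-series idea by propagating truncated Taylor polynomials through each primitive instead of nesting first-order AD. Below is a minimal sketch, assuming that module is available (it is experimental, so its API may change). The example function and the helper `first_n_derivatives` are hypothetical. The input path is taken as $h(t) = x + t$, whose derivative series is $(1, 0, \dots, 0)$. `jet` works with unnormalized derivative coefficients rather than coefficients divided by $k!$, so the output series holds $f'(x), \dots, f^{(n)}(x)$ directly.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet


def f(x):
    # Same illustrative function as before (an assumption): f(x) = sin(x) * exp(x)
    return jnp.sin(x) * jnp.exp(x)


def first_n_derivatives(f, x, n):
    """Hypothetical helper: [f'(x), ..., f^(n)(x)] from a single Taylor-mode pass."""
    x = jnp.asarray(x, dtype=jnp.float32)
    # Derivative series of the input path h(t) = x + t: h' = 1, higher derivatives = 0
    series_in = [jnp.ones_like(x)] + [jnp.zeros_like(x)] * (n - 1)
    _, series_out = jet(f, (x,), (series_in,))
    return series_out


# f and n are Python-level (static) arguments; only x is traced by jit
derivs = jax.jit(first_n_derivatives, static_argnums=(0, 2))(f, 0.5, 5)
print([float(d) for d in derivs])
```

In this scheme each multiplication of two length-$n$ truncated series is a Cauchy product costing $O(n^2)$ operations, so for a fixed program the work grows polynomially in $n$ rather than as $2^n$. This is the usual Taylor-mode argument, not a measured complexity for `jet` itself.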
https://openreview.net/forum?id=SkxEF3FNPH discusses this topic, but doesn't state a computational complexity.