I'm trying to compute the gradient of a lambda function that itself calls gradients of other functions, but the computation hangs and I don't understand why. The code below successfully computes and prints f_next(0.), but evaluating its derivative (the last two lines) never finishes.
Any help would be appreciated.
import jax
import jax.numpy as jnp
# Model parameters
γ = 1.5
k = 0.1
μY = 0.03
σ = 0.03
λ = 0.1
ωb = μY/λ
# PDE params.
σω = σ
dt = 0.01
IC = lambda ω: jnp.exp(-(1-γ)*ω)
f = [IC]
f_x = jax.grad(f[0])            # first derivative
f_xx = jax.grad(jax.grad(f[0])) # second derivative
f_old = f[0]
f_next = lambda ω: f_old(ω) + 100*dt * (
(0.5*σω**2)*f_xx(ω) - λ*(ω-ωb)*f_x(ω)
- k*f_old(ω) + jnp.exp(-(1-γ)*ω))
print(f_next(0.))
f.append(f_next)
f_x = jax.grad(f[1])  # first derivative of f_next -- this is where it hangs
print(f_x(0.))
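One thing I suspect (an unconfirmed assumption on my part): f_next refers to the global names f_x and f_xx, and Python resolves those at call time, not at definition time. So after the reassignment f_x = jax.grad(f[1]), calling the new f_x traces f_next, whose body now calls this same f_x again, recursing without end. A minimal sketch of a workaround, pinning the derivatives at definition time via default arguments:

```python
import jax
import jax.numpy as jnp

# Model parameters (same values as above)
γ = 1.5
k = 0.1
μY = 0.03
σ = 0.03
λ = 0.1
ωb = μY / λ
# PDE params.
σω = σ
dt = 0.01

IC = lambda ω: jnp.exp(-(1 - γ) * ω)
f = [IC]
f_x = jax.grad(f[0])            # first derivative
f_xx = jax.grad(jax.grad(f[0])) # second derivative
f_old = f[0]

# Default arguments capture the *current* f_old/f_x/f_xx, so later
# reassignments of the global names cannot reach into this lambda.
f_next = lambda ω, f_old=f_old, f_x=f_x, f_xx=f_xx: f_old(ω) + 100 * dt * (
    (0.5 * σω**2) * f_xx(ω) - λ * (ω - ωb) * f_x(ω)
    - k * f_old(ω) + jnp.exp(-(1 - γ) * ω))

f.append(f_next)
f_x = jax.grad(f[1])  # safe now: f_next uses its own bound derivatives
print(f_x(0.))
```

With this change the reassignment of f_x no longer feeds back into f_next, and the gradient evaluates immediately.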