
I am trying to learn how to compute the Jacobian of the solution of a vector-valued ODE using JAX. I am using the examples at https://implicit-layers-tutorial.org/implicit_functions/ (that page implements its own ODE integrator along with custom forward-mode and reverse-mode Jacobian functions). I am trying to reproduce that using JAX's own odeint (jax.experimental.ode.odeint) and the diffrax library, but both of these primarily use reverse-mode vector-Jacobian products (VJPs) rather than the forward-mode Jacobian-vector products (JVPs) for which that page provides example code.

Here is a code snippet that I adapted from that page:

```python
import matplotlib.pyplot as plt

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import jit, jvp, vjp
from jax.experimental.ode import odeint

from diffrax import diffeqsolve, ODETerm, PIDController, SaveAt, Dopri5, NoAdjoint

# returns time derivatives of each of our 3 state variables (vector-valued function)
def f(state, t, args):
    x, y, z = state
    rho, sigma, beta = args
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

# convenience function that calls jax odeint given initial conditions and parameters
# (this is the function whose Jacobian/sensitivities we want)
def evolve(y0, rho, sigma, beta):
    return odeint(f, y0, tarr, (rho, sigma, beta))


# set up initial conditions, timespan for integration, and fiducial parameter values
y0 = jnp.array([5., 5., 5.])
tarr = jnp.linspace(0, 1., 1000)
rho = 28.
sigma = 10.
beta = 8/3.


# first just make sure evolve() works
ys = evolve(y0, rho, sigma, beta)

fig, ax = plt.subplots(1, figsize=(6, 4), dpi=150, subplot_kw={'projection': '3d'})
ax.plot(ys.T[0], ys.T[1], ys.T[2], 'b-', lw=0.5)

# now try the reverse-mode vector-Jacobian product (VJP), since forward-mode JVP is not defined for jax odeint
vjp_ys, vjp_evolve = vjp(evolve, y0, rho, sigma, beta)

# vjp_ys and ys are equal -- they are the solution time series of the 3 components (state variables) of y
print(jnp.array_equal(ys, vjp_ys))

# define some perturbation in y0 and parameters
delta_y0 = jnp.array([0., 0., 0.])
delta_rho = 0.
delta_sigma = 0.
delta_beta = 1.

####### THIS FAILS
# vjp_evolve is a function, but I am not sure how to use it to get perturbations delta_ys given y0/parameter variations
vjp_evolve(delta_y0, delta_rho, delta_sigma, delta_beta)
```

That last line raises an error:

```
TypeError: The function returned by `jax.vjp` applied to evolve was called with 4 arguments, but functions returned by `jax.vjp` must be called with a single argument corresponding to the single value returned by evolve (even if that returned value is a tuple or other container).

For example, if we have:

  def f(x):
    return (x, x)
  _, f_vjp = jax.vjp(f, 1.0)

the function `f` returns a single tuple as output, and so we call `f_vjp` with a single tuple as its argument:

  x_bar, = f_vjp((2.0, 2.0))

If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted out' as arguments rather than in a tuple, this error can arise.
```

I suspect I am confused about the concept of reverse-mode VJP and about what its input should be for this vector-valued ODE. The same problem would arise if I used diffrax solvers.

For what it's worth, I can reproduce the forward-mode JVP results on that website if I use a diffrax solver while specifying adjoint=NoAdjoint, so that jax.jvp can be used:

```python
# I am similarly confused about how to use VJP with diffrax's default reverse-mode autodiff of the ODE system;
# however, I am able to use forward-mode JVP with diffrax's ODE solver if I specify adjoint=NoAdjoint()

# diffrax expects the vector field as f(t, y, args), with time first, whereas jax odeint calls func(y, t, *args)
def f_diffrax(t, state, args):
    x, y, z = state
    rho, sigma, beta = args
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

# set up diffrax inputs as closely to jax odeint as possible
terms = ODETerm(f_diffrax)
t0 = 0.0
t1 = 1.0
dt0 = None
max_steps = 16**3  # = 4096, which is also diffrax's default, so this is likely not needed
tsave = SaveAt(ts=tarr, dense=True)

def evolve_diffrax(y0, rho, sigma, beta):
    # rtol/atol of 1.4e-8 match jax odeint's defaults
    return diffeqsolve(terms, Dopri5(), t0, t1, dt0, y0, jnp.array([rho, sigma, beta]), saveat=tsave,
                       stepsize_controller=PIDController(rtol=1.4e-8, atol=1.4e-8), max_steps=max_steps,
                       adjoint=NoAdjoint())

# get solution AND differentials assuming the same changes in y0 and parameters as we tried (and failed) to get above
diffrax_ys, diffrax_delta_ys = jvp(evolve_diffrax, (y0, rho, sigma, beta), (delta_y0, delta_rho, delta_sigma, delta_beta))

# get the actual solution arrays from the diffrax Solution objects
diffrax_ys = diffrax_ys.ys
diffrax_delta_ys = diffrax_delta_ys.ys

# plot the trajectory, with arrows showing the perturbation every 10th time sample
fig, ax = plt.subplots(1, figsize=(6, 4), dpi=150, subplot_kw={'projection': '3d'})
ax.plot(diffrax_ys.T[0], diffrax_ys.T[1], diffrax_ys.T[2], color='violet', lw=0.5)
ax.quiver(diffrax_ys.T[0][::10], diffrax_ys.T[1][::10], diffrax_ys.T[2][::10],
          diffrax_delta_ys.T[0][::10], diffrax_delta_ys.T[1][::10], diffrax_delta_ys.T[2][::10])
```

[Image: 3D trajectory plot with arrows showing the solution's perturbation from varying beta]

That reproduces one of the main plots of that website (showing that the ODE is very sensitive to variations in the beta parameter). So I understand the concept of forward-mode JVP (given perturbations in initial conditions and/or parameters, JVP gives the corresponding perturbation in the ODE solution as a function of time). But what does reverse-mode VJP do and what would be the correct input to the vjp_evolve function above?
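As a minimal sketch of the forward-mode picture described above, the JVP of the trajectory with respect to a unit change in beta can be checked against a central finite difference. This assumes a hypothetical fixed-step Euler integrator `euler_evolve` (written with `jax.lax.scan`, with the parameters packed into one array `[rho, sigma, beta]`) swapped in for `odeint`/diffrax, so plain `jax.jvp` works without any custom rules; it covers 200 steps of `dt=1e-3`, not the question's `tarr` grid.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 so the finite difference is accurate


def euler_evolve(params, dt=1e-3, n_steps=200):
    # hypothetical stand-in for odeint: fixed-step forward Euler on the Lorenz system
    def step(state, _):
        x, y, z = state
        rho, sigma, beta = params
        new_state = state + dt * jnp.array(
            [sigma * (y - x), x * (rho - z) - y, x * y - beta * z])
        return new_state, new_state

    _, ys = jax.lax.scan(step, jnp.array([5.0, 5.0, 5.0]), None, length=n_steps)
    return ys  # shape (n_steps, 3)


params = jnp.array([28.0, 10.0, 8.0 / 3.0])  # [rho, sigma, beta]
d_params = jnp.array([0.0, 0.0, 1.0])        # tangent: perturb beta only

# forward mode: one input tangent -> perturbation of the whole trajectory
ys, d_ys = jax.jvp(euler_evolve, (params,), (d_params,))

# central finite difference along the same direction, for comparison
eps = 1e-6
fd = (euler_evolve(params + eps * d_params) - euler_evolve(params - eps * d_params)) / (2 * eps)
print(d_ys.shape, jnp.max(jnp.abs(d_ys - fd)))
```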

Jim Raynor

1 Answer


JVP is forward-mode autodiff: given tangents of the input to the function at a primal point, it returns tangents on the outputs.

VJP is reverse-mode autodiff: given cotangents on the output of the function at a primal point, it returns cotangents on the inputs.
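A minimal sketch of these two definitions, using a hypothetical linear function `f(x) = A @ x` whose Jacobian is exactly the matrix `A`, so the results can be read off directly:

```python
import jax
import jax.numpy as jnp

# hypothetical linear map from R^3 to R^2; its Jacobian is A (shape 2x3)
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

def f(x):
    return A @ x

x0 = jnp.array([0.5, -1.0, 2.0])  # primal point
v = jnp.array([1.0, 0.0, 0.0])    # tangent: shaped like the INPUT (length 3)
u = jnp.array([0.0, 1.0])         # cotangent: shaped like the OUTPUT (length 2)

# JVP pushes an input tangent forward: J @ v, here the first column of A
y, jv = jax.jvp(f, (x0,), (v,))

# VJP pulls an output cotangent back: u @ J, here the second row of A
y2, f_vjp = jax.vjp(f, x0)
(uj,) = f_vjp(u)  # one cotangent per primal input, returned as a tuple

print(jv, uj)
```

The same holds when there are many more outputs than inputs, as with an ODE solution sampled at many times: `jvp` takes tangents shaped like the inputs, while the function returned by `vjp` takes a single cotangent shaped like the output.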

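So you can call vjp_evolve with a single cotangent of the same shape as vjp_ys; it returns a tuple containing one cotangent for each of the inputs y0, rho, sigma and beta: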

```python
print(vjp_evolve(jnp.ones_like(vjp_ys)))
```
```
(Array([ 1.74762118, 26.45747015, -2.03017559], dtype=float64),
 Array(871.66349663, dtype=float64),
 Array(-83.07586548, dtype=float64),
 Array(-1754.48788565, dtype=float64))
```
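One way to read the all-ones cotangent: it asks for the gradient of the plain sum of every entry of the output, and a general cotangent `u` asks for the gradient of the weighted sum `sum(u * ys)`. The sketch below checks both, assuming a hypothetical fixed-step Euler stand-in `euler_evolve` for `odeint`, with the parameters packed into one array rather than passed separately as in `evolve`:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def euler_evolve(params, dt=1e-3, n_steps=200):
    # hypothetical stand-in for odeint: fixed-step forward Euler on the Lorenz system
    def step(state, _):
        x, y, z = state
        rho, sigma, beta = params
        new_state = state + dt * jnp.array(
            [sigma * (y - x), x * (rho - z) - y, x * y - beta * z])
        return new_state, new_state

    _, ys = jax.lax.scan(step, jnp.array([5.0, 5.0, 5.0]), None, length=n_steps)
    return ys  # shape (n_steps, 3)


params = jnp.array([28.0, 10.0, 8.0 / 3.0])  # [rho, sigma, beta]
ys, vjp_fn = jax.vjp(euler_evolve, params)

# all-ones cotangent == gradient of the sum of every output entry
(g_ones,) = vjp_fn(jnp.ones_like(ys))
g_sum = jax.grad(lambda p: jnp.sum(euler_evolve(p)))(params)

# any other cotangent u (e.g. random numbers) == gradient of sum(u * ys)
u = jax.random.normal(jax.random.PRNGKey(0), ys.shape)
(g_u,) = vjp_fn(u)
g_weighted = jax.grad(lambda p: jnp.sum(u * euler_evolve(p)))(params)
print(g_ones, g_sum)
```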

Conceptually, JVP propagates gradients forward through a computation, while VJP propagates gradients backward. The JAX docs might be useful background for understanding the JVP & VJP transformations more deeply: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff
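To make "forward" versus "backward" concrete: JVP and VJP apply the same linear map (the Jacobian J at the primal point) from opposite sides, so for any tangent v and cotangent u, u · (J v) equals (u^T J) · v. A sketch checking this identity, again assuming the hypothetical Euler stand-in `euler_evolve` rather than `odeint`:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def euler_evolve(params, dt=1e-3, n_steps=200):
    # hypothetical stand-in for odeint: fixed-step forward Euler on the Lorenz system
    def step(state, _):
        x, y, z = state
        rho, sigma, beta = params
        new_state = state + dt * jnp.array(
            [sigma * (y - x), x * (rho - z) - y, x * y - beta * z])
        return new_state, new_state

    _, ys = jax.lax.scan(step, jnp.array([5.0, 5.0, 5.0]), None, length=n_steps)
    return ys  # shape (n_steps, 3)


params = jnp.array([28.0, 10.0, 8.0 / 3.0])
key_u, key_v = jax.random.split(jax.random.PRNGKey(0))

v = jax.random.normal(key_v, params.shape)       # tangent in parameter space
ys, jv = jax.jvp(euler_evolve, (params,), (v,))  # forward: J @ v, shape (200, 3)

u = jax.random.normal(key_u, ys.shape)           # cotangent in output space
_, vjp_fn = jax.vjp(euler_evolve, params)
(utj,) = vjp_fn(u)                               # backward: u^T J, shape (3,)

lhs = jnp.sum(u * jv)   # u . (J v)
rhs = jnp.sum(utj * v)  # (u^T J) . v
print(lhs, rhs)
```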

jakevdp
  • Thank you! I did read that page and others, but I'm still confused about the intuitive meaning of the input in your successful call to vjp_evolve. What does it mean that you're feeding in all ones? What if you fed in a mixture of ones and zeros, or a bunch of random numbers? Is the VJP asking what change in the inputs would produce the requested change in the outputs (in your case, increasing all y's by 1)? But what if many such input changes can produce the requested output change? – Jim Raynor Mar 12 '23 at 19:54
  • If you think of the inputs and outputs as high-dimensional vectors, then the values you feed in encode the direction and magnitude of the output gradient (i.e. the cotangent vector). If you put in different numbers, you are evaluating the VJP at a different cotangent. Regarding the effect on inputs: since the VJP operates on infinitesimal quantities, you're evaluating it in the linear regime, so there's only one possible input cotangent for a given output cotangent at a given primal value. – jakevdp Mar 12 '23 at 21:54
  • Thanks! I think one of the things that's hard to understand here is that we're solving a vector-valued ODE (i.e., a system) which produces N outputs (the multiple components of the solution sampled at different times), as opposed to just the vector-valued solution at a single final time. So the VJP in our case depends on the full history of the solution, and it sounds like jnp.ones_like(vjp_ys) is saying to weight the contribution of the solution at each time equally in the total gradient of evolve. Is there a more explicit, intuitive way to see this, maybe using jax.jacrev? – Jim Raynor Mar 13 '23 at 04:41
  • In terms of dependence on the history of the system, the VJP isn't any different from the JVP. At any timestep your system has its value (the primals) and its derivatives (the tangents). JVP takes a primal value and maps input tangents to output tangents, and you are free to choose whatever values for the input tangents you want. The VJP takes a primal value and maps output tangents to input tangents, and you are free to choose whatever values for the output tangents that you want. Solution of a particular problem constrains those inputs, but that's external to the JVP/VJP operation. – jakevdp Mar 13 '23 at 13:17
  • Okay thank you -- I swear one last question. I have heard that JVP builds the Jacobian one column at a time and VJP one row at a time (so gradients aka the row vector of all derivatives of one component of the vector-valued function being differentiated). How does your vector of all ones for vjp_evolve relate to this? I have seen people use vectors like [0,0,1,0] for vjp and say that it "pulls out" the gradient of only the 3rd component. So is your vector of all ones somehow combining the gradients of all components and mapping that to the equivalent change in input parameters? – Jim Raynor Mar 13 '23 at 14:53
  • That's exactly right: the way that `jacrev` constructs the jacobian is by `vmap`-ing the `vjp` over an identity matrix (see the code [here](https://github.com/google/jax/blob/5aa74acbcd24785a72ecb3e8468fc3e6a1befea0/jax/_src/api.py#L1467) – `_std_basis` is basically an identity matrix in the output space), so each row of the jacobian effectively comes from calling `vjp` on a one-hot vector. Passing a vector of ones is like multiplying the jacobian matrix by a vector of ones, but without the intermediate step of constructing the full jacobian matrix. – jakevdp Mar 13 '23 at 21:13
  • Alright, I think I've simplified my problem in the context of the Jacobian of a vector-valued ODE system with N parameters and M state variables. If I restrict myself to the ODE outputs at a single time (say the final time of integration), then the Jacobian is just an MxN matrix and I can easily compute how varying the N parameters changes the final values of the M variables. That is, have evolve() return `odeint(f, y0, tarr, (rho, sigma, beta))[-1]`. However, I still don't understand what the Jacobian means when evolve() returns the full time-dependent array of solution values. – Jim Raynor Mar 13 '23 at 21:55
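A sketch of what the full Jacobian looks like when the output is the whole time series, assuming the same kind of hypothetical Euler stand-in `euler_evolve` (200 saved steps, M=3 state variables, N=3 parameters packed as `[rho, sigma, beta]`) in place of `odeint`. `jax.jacrev` returns an array of shape (200, M, N), where entry `[k, i, j]` is the sensitivity of state variable i at step k to parameter j. Slicing `[-1]` gives the MxN final-time Jacobian; a one-hot cotangent pulls out one row of the flattened (200*M) x N Jacobian; an all-ones cotangent sums all rows; and a JVP along one parameter gives one column:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def euler_evolve(params, dt=1e-3, n_steps=200):
    # hypothetical stand-in for odeint: fixed-step forward Euler on the Lorenz system
    def step(state, _):
        x, y, z = state
        rho, sigma, beta = params
        new_state = state + dt * jnp.array(
            [sigma * (y - x), x * (rho - z) - y, x * y - beta * z])
        return new_state, new_state

    _, ys = jax.lax.scan(step, jnp.array([5.0, 5.0, 5.0]), None, length=n_steps)
    return ys  # shape (n_steps, 3)


params = jnp.array([28.0, 10.0, 8.0 / 3.0])

J = jax.jacrev(euler_evolve)(params)      # reverse mode, shape (200, 3, 3)
J_fwd = jax.jacfwd(euler_evolve)(params)  # forward mode builds the same array
J_final = J[-1]                           # 3x3 Jacobian at the final time only

ys, vjp_fn = jax.vjp(euler_evolve, params)
# one-hot cotangent (final time, z component) -> one row of the Jacobian
(row,) = vjp_fn(jnp.zeros_like(ys).at[-1, 2].set(1.0))
# all-ones cotangent -> sum of all rows (every time, every state variable)
(g_ones,) = vjp_fn(jnp.ones_like(ys))

# JVP along beta -> one column: the beta sensitivity at every time step
_, d_ys = jax.jvp(euler_evolve, (params,), (jnp.array([0.0, 0.0, 1.0]),))
print(J.shape, J_final.shape)
```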