I'm working on a problem that involves computing the values of many interpolants on a three-dimensional grid using JAX. Following standard JAX practice, I wrote everything for "single-batch" inputs and then vmap over all interpolants and evaluation grid points at the end. The code below is a reduced, nonsensical version of this.
from collections import namedtuple
from functools import partial

import jax.numpy as jnp
from jax import vmap
from jax.lax import dynamic_slice, stop_gradient

interpolation_params = namedtuple("interpolation_params", ["a", "dx", "f", "lb", "ub"])


@partial(vmap, in_axes=(None, None, 0))
def init_1d_interpolation_params(a, dx, f):
    f = jnp.pad(f, 1)
    lb, ub = a, a + (f.shape[0] - 1) * dx
    return interpolation_params(a=a, dx=dx, f=f, lb=lb, ub=ub)


@partial(vmap, in_axes=(None, 0))
def eval_interp1d(x, interpolation_params):
    A = jnp.array([-1.0 / 16, 9.0 / 16, 9.0 / 16, -1.0 / 16])
    B = jnp.array([1.0 / 24, -9.0 / 8, 9.0 / 8, -1.0 / 24])
    C = jnp.array([1.0 / 4, -1.0 / 4, -1.0 / 4, 1.0 / 4])
    D = jnp.array([-1.0 / 6, 1.0 / 2, -1.0 / 2, 1.0 / 6])

    x = (
        jnp.minimum(jnp.maximum(x, interpolation_params.lb), interpolation_params.ub)
        - interpolation_params.a
    )
    ix = jnp.atleast_1d(jnp.array(x // interpolation_params.dx, int))
    ratx = x / interpolation_params.dx - (ix + 0.5)
    asx = A + ratx * (B + ratx * (C + ratx * D))
    return jnp.dot(dynamic_slice(interpolation_params.f, ix, (4,)), asx)


# Init 300 interpolants on a uniform grid with 4096 points
x = jnp.linspace(0, 1, 4096)
f = x**2
ff = jnp.repeat(f.reshape(1, -1), 300, axis=0)
params = init_1d_interpolation_params(x[0], x[1] - x[0], ff)


@partial(vmap, in_axes=(0, None))
def foo(x, interpolation_params):
    g_x = (eval_interp1d(x, interpolation_params)) ** 2
    return jnp.sum(g_x)


large_x_array = stop_gradient(jnp.repeat(jnp.array([0.0]), 100**3))
foo(large_x_array, params)
Now, if I run the code above, I end up with a very large memory footprint (14 GB), which is a little puzzling to me. Initially, I thought the issue was the computation-graph tracing of JAX's autodiff backend, the size of which should naively be comparable to the "cartesian product" of params and large_x_array. However, using stop_gradient to turn off the graph tracing didn't help, so I'm not exactly sure what's going on or how to fix this elegantly. Any thoughts on this?
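For scale, here is a back-of-the-envelope estimate of what one array over the full "cartesian product" costs. It is only a sketch under assumptions that were not measured: float32 values (4 bytes each), and every intermediate over all (interpolant, point) pairs being materialised in full. The names `n_interpolants`, `n_points`, `stencil` and `gib` are made up for this illustration.

```python
# Rough size of arrays spanning all (interpolant, point) pairs.
# Assumption: float32 storage, i.e. 4 bytes per element.
n_interpolants = 300      # rows of ff
n_points = 100**3         # length of large_x_array
stencil = 4               # length of the slice taken by dynamic_slice
bytes_per_float32 = 4


def gib(n_elements, itemsize=bytes_per_float32):
    """Convert an element count into GiB for the given item size."""
    return n_elements * itemsize / 2**30


# One scalar per (interpolant, point) pair, e.g. an array shaped like ratx.
per_pair = gib(n_interpolants * n_points)

# Four values per pair, e.g. an array shaped like asx or the sliced f values.
per_stencil = gib(n_interpolants * n_points * stencil)

print(f"one value per pair:   {per_pair:.2f} GiB")
print(f"four values per pair: {per_stencil:.2f} GiB")
```

Under these assumptions a single array with one value per pair is about 1.1 GiB, and one with four values per pair is about 4.5 GiB, so a handful of such intermediates alive at the same time would already be in the range of the 14 GB observed, independent of any autodiff bookkeeping.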