
I'm working on a problem that involves computing the values of many interpolants on a three-dimensional grid using JAX. Following standard JAX practice, I wrote everything for single-batch inputs and then used vmap over all interpolants and evaluation grid points at the end. The code below is a reduced, nonsensical version of this.

from collections import namedtuple
from functools import partial

import jax.numpy as jnp
from jax import vmap
from jax.lax import dynamic_slice, stop_gradient

interpolation_params = namedtuple("interpolation_params", ["a", "dx", "f", "lb", "ub"])


@partial(vmap, in_axes=(None, None, 0))
def init_1d_interpolation_params(a, dx, f):
    f = jnp.pad(f, 1)
    lb, ub = a, a + (f.shape[0] - 1) * dx
    return interpolation_params(a=a, dx=dx, f=f, lb=lb, ub=ub)


@partial(vmap, in_axes=(None, 0))
def eval_interp1d(x, interpolation_params):
    A = jnp.array([-1.0 / 16, 9.0 / 16, 9.0 / 16, -1.0 / 16])
    B = jnp.array([1.0 / 24, -9.0 / 8, 9.0 / 8, -1.0 / 24])
    C = jnp.array([1.0 / 4, -1.0 / 4, -1.0 / 4, 1.0 / 4])
    D = jnp.array([-1.0 / 6, 1.0 / 2, -1.0 / 2, 1.0 / 6])

    x = (
        jnp.minimum(jnp.maximum(x, interpolation_params.lb), interpolation_params.ub)
        - interpolation_params.a
    )
    ix = jnp.atleast_1d(jnp.array(x // interpolation_params.dx, int))
    ratx = x / interpolation_params.dx - (ix + 0.5)
    asx = A + ratx * (B + ratx * (C + ratx * D))
    return jnp.dot(dynamic_slice(interpolation_params.f, ix, (4,)), asx)


# Init 300 interpolants on a uniform grid with 4096 points
x = jnp.linspace(0, 1, 4096)
f = x**2
ff = jnp.repeat(f.reshape(1, -1), 300, axis=0)

params = init_1d_interpolation_params(x[0], x[1] - x[0], ff)


@partial(vmap, in_axes=(0, None))
def foo(x, interpolation_params):
    g_x = (eval_interp1d(x, interpolation_params)) ** 2
    return jnp.sum(g_x)


large_x_array = stop_gradient(jnp.repeat(jnp.array([0.0]), 100**3))
foo(large_x_array, params)

Now, if I run the code above, I end up with a very large memory footprint (14 GB), which is a little puzzling to me. Initially, I thought the issue was the computation graph traced by JAX's autodiff machinery, whose size should naively be comparable to the "Cartesian product" of params and large_x_array. However, using stop_gradient to turn off that tracing didn't help, so I'm not sure what's going on or how to fix this elegantly. Any thoughts on this?

crypty

1 Answer

Your code is essentially a doubly-nested vmap over axes of size 300 and 1000000, respectively. This means that the effective memory footprint of a fully-vmapped float32 scalar in your inner function is 300 * 1000000 * 4 bytes, or about 1.2 GB. Your inner function constructs fully-mapped arrays of length 4, which take up about 4.8 GB each; with that in mind, it's not surprising that your full function would require allocating a few times that amount of memory.
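As a quick check of that arithmetic, here is the same back-of-the-envelope calculation in Python. It assumes the shapes from the question (100**3 evaluation points, 300 interpolants, a 4-point stencil) and float32 values, which is JAX's default when 64-bit mode is off; "GB" here means 10**9 bytes.

```python
import numpy as np

n_points = 100**3  # evaluation points (outer vmap axis)
n_interp = 300     # interpolants (inner vmap axis)
stencil = 4        # length of the per-point arrays (asx, the dynamic_slice window)
itemsize = np.dtype(np.float32).itemsize  # 4 bytes per float32

scalar_bytes = n_points * n_interp * itemsize  # one fully-mapped scalar
stencil_bytes = scalar_bytes * stencil         # one fully-mapped length-4 array

print(f"{scalar_bytes / 1e9:.1f} GB per mapped scalar")
print(f"{stencil_bytes / 1e9:.1f} GB per mapped length-4 array")
```

A handful of such length-4 intermediates alive at the same time is enough to reach the double-digit gigabytes reported in the question.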

For what it's worth, if you want to see the array sizes implied by your outer function, one way to do that is to construct the jaxpr representing your end-to-end operation:

import jax
print(jax.make_jaxpr(foo)(large_x_array, params))

The output is long, so I won't paste it in full here, but in the jaxpr you see direct evidence of what I said above, for example:

...
    cf:f32[1000000,300,4] = add ce cc
    cg:f32[1000000,300,4] = mul bv cf
...

These are intermediate arrays of 1000000x300x4 float32 values, i.e. 1000000x300x4x4 bytes (about 4.8 GB), which are allocated in the course of executing your code.
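Rather than reading the whole jaxpr by eye, one option is to walk its top-level equations and rank their outputs by size. The sketch below does this for a small toy function with the same doubly-nested vmap structure as foo; `toy`, `toy_mapped` and `intermediate_sizes` are hypothetical names for illustration. It only inspects top-level equations (it does not recurse into nested jaxprs, e.g. from jit-wrapped jnp functions), and it relies on the ClosedJaxpr returned by make_jaxpr exposing `.jaxpr.eqns`, whose internals may change between JAX versions.

```python
import math

import jax
import jax.numpy as jnp


def toy(x, w):
    # Per-(x, w) length-4 intermediate, analogous to asx in eval_interp1d.
    v = w * jnp.arange(4.0) + x
    return jnp.sum(v * v)


# Inner vmap over w (like the interpolants), outer vmap over x (like the grid).
toy_mapped = jax.vmap(jax.vmap(toy, in_axes=(None, 0)), in_axes=(0, None))


def intermediate_sizes(fn, *args):
    """List (nbytes, shape, primitive) for each top-level jaxpr output, largest first."""
    closed = jax.make_jaxpr(fn)(*args)
    sizes = []
    for eqn in closed.jaxpr.eqns:
        for var in eqn.outvars:
            aval = var.aval
            nbytes = math.prod(aval.shape) * aval.dtype.itemsize
            sizes.append((nbytes, tuple(aval.shape), str(eqn.primitive)))
    return sorted(sizes, key=lambda s: s[0], reverse=True)


xs = jnp.zeros(1000)  # stands in for large_x_array
ws = jnp.ones(30)     # stands in for the 300 interpolants
for nbytes, shape, prim in intermediate_sizes(toy_mapped, xs, ws)[:3]:
    print(f"{prim:>20} {shape} {nbytes:,} bytes")
```

The largest entries should be the 1000x30x4 float32 intermediates, the toy counterpart of the f32[1000000,300,4] arrays in the excerpt above.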

If you want to reduce memory consumption, you could serialize some of the computation using scan or fori_loop in place of vmap, so that the full multi-gigabyte intermediate arrays are never allocated at once.
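A minimal sketch of that idea, assuming the evaluation points can be split into equal-sized chunks: the already-vmapped function is applied to one chunk at a time with jax.lax.map (which lowers to a scan), so only one chunk's intermediates need to be live at once. `chunked_apply` and `toy_foo` are hypothetical names for illustration. With the question's code, something like `chunked_apply(foo, large_x_array, params, chunk_size=10_000)` would bound each length-4 intermediate at roughly 10,000 x 300 x 4 x 4 bytes, about 48 MB per chunk, though actual peak memory also depends on what XLA decides to fuse.

```python
from functools import partial

import jax
import jax.numpy as jnp


def chunked_apply(fn, xs, *args, chunk_size):
    """Evaluate the already-vmapped fn(xs_chunk, *args) one chunk at a time.

    Assumes fn maps over axis 0 of its first argument, returns a single array,
    and that xs.shape[0] is a multiple of chunk_size (pad xs otherwise).
    """
    n = xs.shape[0]
    if n % chunk_size:
        raise ValueError("xs.shape[0] must be a multiple of chunk_size")
    chunks = xs.reshape(n // chunk_size, chunk_size, *xs.shape[1:])
    # lax.map runs the chunks sequentially instead of materializing the
    # full (n, ...) intermediates the way a single big vmap does.
    out = jax.lax.map(lambda c: fn(c, *args), chunks)
    return out.reshape(n, *out.shape[2:])


@partial(jax.vmap, in_axes=(0, None))
def toy_foo(x, w):
    # Stand-in for the question's foo: one scalar result per evaluation point.
    return jnp.sum((x * w) ** 2)


xs = jnp.linspace(0.0, 1.0, 12)
w = jnp.array([1.0, 2.0, 3.0])
chunked = jax.jit(partial(chunked_apply, toy_foo, chunk_size=4))(xs, w)
print(jnp.allclose(chunked, toy_foo(xs, w)))
```

The chunk size trades memory for parallelism: larger chunks keep more of the work vectorized, smaller ones lower the peak allocation. Recent JAX releases also accept a `batch_size` argument to `jax.lax.map` that performs this kind of chunking internally; check the documentation for your installed version before relying on it.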

jakevdp