
Apologies in advance for how vague this question is (unfortunately I don't know enough about how jax tracing works to phrase it more precisely), but is there a way to completely insulate a function or code block from jax tracing?

For context, I have a function of the form:

def f(x, y):
   z = h(y)
   return g(x, z)

Essentially, I want to call g(x, z) and treat z as a constant under any jax transformation. However, setting up the argument z is very awkward, so the helper function h transforms an easier-to-specify input y into the format g requires. What I'd like is for jax to treat h as a non-traceable black box, so that jit(lambda x: f(x, y0)) for a particular y0 behaves the same as first computing z0 = h(y0) with numpy and then doing jit(lambda x: g(x, z0)) (and similarly for grad or any other function transformation).
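A minimal sketch of the two setups being compared. The bodies of h and g here are hypothetical stand-ins (h pads ragged integer rows into a dense float32 array with plain loops; g is a weighted sum), as are y0 and x. It only illustrates that closing over y0 and precomputing z0 give the same numbers; it says nothing about compile time.

```python
import numpy as np
import jax
import jax.numpy as jnp

def h(y):
    # Hypothetical stand-in for the awkward setup: ragged integer rows are
    # padded into a dense float32 array using Python loops and NumPy only.
    width = max(len(row) for row in y)
    z = np.zeros((len(y), width), dtype=np.float32)
    for i, row in enumerate(y):
        for j, v in enumerate(row):
            if v > 0:  # data-dependent branch, fine in plain NumPy
                z[i, j] = v
    return z

def g(x, z):
    # Hypothetical jax-friendly computation that consumes z.
    return jnp.sum(jnp.asarray(z) @ x)

def f(x, y):
    z = h(y)
    return g(x, z)

y0 = ((1, 2, 3), (4, 5), (6,))  # ragged, integer-valued input
x = jnp.arange(3.0)

# Setup 1: close over y0, so h runs while jit traces the lambda.
f_y0 = jax.jit(lambda x: f(x, y0))

# Setup 2: precompute z0 with NumPy, then jit only g.
z0 = h(y0)
g_z0 = jax.jit(lambda x: g(x, z0))

print(f_y0(x), g_z0(x))  # both 13.0
```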

In my code, I've already written h to only use standard numpy (which I thought might lead to black-box behaviour), but the compile time of jit(lambda x: f(x, y0)) is noticeably longer than the compile time of jit(lambda x: g(x, z0)) for z0 = h(y0). I suspect the extra compile time comes from jax tracing the many loops in h, though I'm not sure.

Some additional notes:

  • Writing h in a jax-friendly way would be awkward (input formatting is ragged, tons of looping/conditionals, output shape dependent on input value, etc) and ultimately more trouble than it's worth as the function is extremely cheap to execute, and I don't ever need to differentiate it (the input data is integer-based).
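Since only x needs a gradient, one way to sketch this is to differentiate with respect to x alone while y0 stays a plain tuple of Python ints, so h never receives a tracer. Same hypothetical h and g as a stand-in; for g(x, z) = sum(z @ x) the gradient with respect to x is just the column sums of z.

```python
import numpy as np
import jax
import jax.numpy as jnp

def h(y):
    # Hypothetical NumPy-only setup (ragged integer rows -> padded array).
    width = max(len(row) for row in y)
    z = np.zeros((len(y), width), dtype=np.float32)
    for i, row in enumerate(y):
        for j, v in enumerate(row):
            if v > 0:
                z[i, j] = v
    return z

def g(x, z):
    # Hypothetical jax-friendly part.
    return jnp.sum(jnp.asarray(z) @ x)

def f(x, y):
    return g(x, h(y))

y0 = ((1, 2, 3), (4, 5), (6,))
x = jnp.arange(3.0)

# Differentiate w.r.t. x only; z = h(y0) is computed eagerly and is a constant.
df = jax.grad(lambda x: f(x, y0))
print(df(x))  # column sums of z: [11. 7. 3.]
```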

Thoughts?

Edit for clarity: I know there may be ways around this if, e.g., f is a top-level function. In that case it isn't a big deal to have the user call h first to "pre-compile" the jax-friendly inputs to g, and then freely apply whatever jax transformations they want to lambda x: g(x, z0). However, I'm imagining cases with many functions, all with the same structure as f, that we want to chain together: each has some jax-unfriendly inputs/computations, but those inputs are always treated as constant by the jax part of the computation. In principle one could always pull these pre-computations out to set up the jax part, but that seems difficult with a non-trivial collection of such functions calling each other.
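A sketch of the chained case, assuming the jax-unfriendly inputs stay concrete Python/NumPy values that are closed over rather than passed as traced arguments of the jitted function. h2, f2, pipeline, w0 and the bodies of h and g are all hypothetical. Each setup step then runs eagerly while the outer jit traces, so nothing has to be pulled out by hand. If y were instead an ordinary (non-static) argument of the jitted function, h would receive tracers and its Python branching would fail.

```python
import numpy as np
import jax
import jax.numpy as jnp

def h(y):
    # Hypothetical NumPy-only setup (ragged integer rows -> padded array).
    width = max(len(row) for row in y)
    z = np.zeros((len(y), width), dtype=np.float32)
    for i, row in enumerate(y):
        for j, v in enumerate(row):
            if v > 0:
                z[i, j] = v
    return z

def g(x, z):
    return jnp.sum(jnp.asarray(z) @ x)

def f(x, y):
    return g(x, h(y))

def h2(w):
    # Second hypothetical NumPy-only setup: integer labels -> weights.
    return np.array([float(v % 3) for v in w], dtype=np.float32)

def f2(u, w):
    # Another f-shaped function: NumPy setup on w, then a jax computation.
    return u * jnp.asarray(h2(w))

def pipeline(x, y, w):
    # f-shaped functions nested inside each other under a single jit.
    return jnp.sum(f2(f(x, y), w))

y0 = ((1, 2, 3), (4, 5), (6,))
w0 = (1, 2, 3, 4)
run = jax.jit(lambda x: pipeline(x, y0, w0))
print(run(jnp.arange(3.0)))  # 13 * (1 + 2 + 0 + 1) = 52.0
```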

Is there some way to control how f gets traced, so that while tracing it knows to just evaluate z=h(y) (instead of tracing h) then continue with tracing g(x, z)?
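jax only records operations applied to tracers; NumPy calls on concrete values such as y0 execute immediately in Python during tracing. One way to check what actually ends up in the traced program is jax.make_jaxpr, and a call counter (hypothetical, added to the stand-in h) shows how often h really runs. Comparing the output with jax.make_jaxpr(lambda x: g(x, z0))(x) is a cheap way to see whether the two programs differ at all.

```python
import numpy as np
import jax
import jax.numpy as jnp

calls = {"h": 0}  # hypothetical counter: how often the NumPy setup runs

def h(y):
    calls["h"] += 1
    width = max(len(row) for row in y)
    z = np.zeros((len(y), width), dtype=np.float32)
    for i, row in enumerate(y):
        for j, v in enumerate(row):
            if v > 0:
                z[i, j] = v
    return z

def g(x, z):
    return jnp.sum(jnp.asarray(z) @ x)

def f(x, y):
    return g(x, h(y))

y0 = ((1, 2, 3), (4, 5), (6,))
x = jnp.arange(3.0)

# Only x is an input of the traced program; h's Python loops and ifs run
# eagerly, so they leave no while/scan/cond primitives behind.
closed = jax.make_jaxpr(lambda x: f(x, y0))(x)
print(closed)
prims = {eqn.primitive.name for eqn in closed.jaxpr.eqns}

# Under jit, h runs once per trace rather than once per call.
f_y0 = jax.jit(lambda x: f(x, y0))
f_y0(x)
f_y0(x)
print(calls["h"])  # 2: once for make_jaxpr, once for the jit trace
```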

1 Answer

f_jitted = jax.jit(f, static_argnums=1)

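The static_argnums parameter could probably help. Marking y (argument index 1) as static means jit receives its concrete value instead of a tracer, so h(y) runs as ordinary Python/numpy during tracing and only g(x, z) is traced.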
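A minimal sketch of that suggestion, with the same hypothetical stand-ins for h and g. One assumption worth stating: static arguments must be hashable, so the ragged input is passed as a tuple of tuples rather than a list or numpy array.

```python
import numpy as np
import jax
import jax.numpy as jnp

def h(y):
    # Hypothetical NumPy-only setup (ragged integer rows -> padded array).
    width = max(len(row) for row in y)
    z = np.zeros((len(y), width), dtype=np.float32)
    for i, row in enumerate(y):
        for j, v in enumerate(row):
            if v > 0:
                z[i, j] = v
    return z

def g(x, z):
    return jnp.sum(jnp.asarray(z) @ x)

def f(x, y):
    z = h(y)
    return g(x, z)

# Argument 1 (y) is static: jit hands h the concrete tuple, not a tracer.
f_jitted = jax.jit(f, static_argnums=1)

y0 = ((1, 2, 3), (4, 5), (6,))  # tuple of tuples, so it is hashable
x = jnp.arange(3.0)
print(f_jitted(x, y0))  # 13.0
```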

From the JAX documentation: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

You can use transformation parameters such as static_argnums for jit to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
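In practice, "more recompiles" means each distinct value of y triggers a fresh trace (re-running h) and a fresh compile, while repeated calls with an equal, hashable y hit jit's cache. The sketch below counts calls to the hypothetical stand-in h to show this, and shows that an unhashable y such as a list is rejected. The exception type is caught broadly because it is an assumption here.

```python
import numpy as np
import jax
import jax.numpy as jnp

calls = {"h": 0}  # hypothetical counter: one increment per trace of f

def h(y):
    calls["h"] += 1
    width = max(len(row) for row in y)
    z = np.zeros((len(y), width), dtype=np.float32)
    for i, row in enumerate(y):
        for j, v in enumerate(row):
            if v > 0:
                z[i, j] = v
    return z

def g(x, z):
    return jnp.sum(jnp.asarray(z) @ x)

def f(x, y):
    return g(x, h(y))

f_jitted = jax.jit(f, static_argnums=1)
x = jnp.arange(3.0)

f_jitted(x, ((1, 2, 3), (4, 5), (6,)))
f_jitted(x, ((1, 2, 3), (4, 5), (6,)))  # equal, hashable y: cached, h not rerun
after_same = calls["h"]                 # 1
f_jitted(x, ((7, 8, 9),))               # new y value: retrace and recompile
after_new = calls["h"]                  # 2

# Lists (and numpy arrays) are unhashable, so jit refuses them as static args.
try:
    f_jitted(x, [[1, 2, 3], [4, 5], [6]])
    list_rejected = False
except (ValueError, TypeError):
    list_rejected = True
```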