Apologies in advance for how vague this question is (unfortunately I don't know enough about how JAX tracing works to phrase it more precisely), but is there a way to completely insulate a function or code block from JAX tracing?
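For context, I have a function of the form:

```python
def f(x, y):
    z = h(y)        # awkward preprocessing, written in plain NumPy
    return g(x, z)  # the part JAX should trace and transform
```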
Essentially, I want to call `g(x, z)` and treat `z` as a constant under any JAX transformation. However, setting up the argument `z` directly is very awkward, so the helper function `h` transforms an easier-to-specify input `y` into the format required by `g`. What I'd like is for JAX to treat `h` as a non-traceable black box, so that `jit(lambda x: f(x, y0))` for a particular `y0` is equivalent to first computing `z0 = h(y0)` with NumPy and then doing `jit(lambda x: g(x, z0))` (and likewise for `grad` or any other function transformation).
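A minimal sketch of the setup described above, for anyone wanting to reproduce it. The bodies of `h` and `g` here are hypothetical placeholders (a ragged, loop-based NumPy `h` whose output length depends on the input values, and a simple elementwise `g`), not the real functions. `jax.make_jaxpr` prints the program JAX records while tracing, which is one way to compare what ends up in the `f` version versus the `g` version.

```python
import numpy as np
import jax
import jax.numpy as jnp


def h(y):
    # Hypothetical stand-in for the awkward preprocessing: plain NumPy,
    # a Python loop over ragged input, and an output length that depends
    # on the *values* in y.
    return np.concatenate([np.arange(n, dtype=np.float32) for n in y])


def g(x, z):
    # Hypothetical JAX-traceable computation that treats z as data.
    return jnp.sum(jnp.sin(x) * z)


def f(x, y):
    z = h(y)
    return g(x, z)


y0 = [2, 3]   # ragged, integer-valued input
z0 = h(y0)    # precomputed with NumPy: [0., 1., 0., 1., 2.]
x0 = jnp.ones(z0.shape, dtype=jnp.float32)

# y0 is captured by the lambda as an ordinary Python list rather than passed
# as a traced argument, so h(y0) receives concrete values while jit traces x.
f_jit = jax.jit(lambda x: f(x, y0))
g_jit = jax.jit(lambda x: g(x, z0))
grad_f = jax.grad(lambda x: f(x, y0))

print(f_jit(x0), g_jit(x0))
print(grad_f(x0))

# Print the traced programs side by side to see what each one records.
print(jax.make_jaxpr(lambda x: f(x, y0))(x0))
print(jax.make_jaxpr(lambda x: g(x, z0))(x0))
```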
In my code, I've already written `h` to use only standard NumPy (which I thought might lead to black-box behaviour), but the compile time of `jit(lambda x: f(x, y0))` is noticeably longer than that of `jit(lambda x: g(x, z0))` for `z0 = h(y0)`. I suspect the extra time may come from JAX tracing the many loops in `h`, though I'm not sure.
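One way to narrow down where the extra time goes is to split the cost of `jit` into tracing plus lowering (`jax.jit(fn).lower(...)`, which runs the Python function once) and XLA compilation (`.compile()` on the lowered object). The sketch below assumes the same hypothetical placeholder `h` and `g` as before and a made-up larger input; if `h` runs during tracing, its Python-side cost is included in the first number, while the second number is compilation only. Absolute timings are machine-dependent.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


def h(y):
    # Hypothetical NumPy-only preprocessing (same role as h in the question).
    return np.concatenate([np.arange(n, dtype=np.float32) for n in y])


def g(x, z):
    # Hypothetical JAX computation that consumes the preprocessed z.
    return jnp.sum(jnp.sin(x) * z)


def f(x, y):
    return g(x, h(y))


def time_jit(fn, *args):
    """Split jit cost into (tracing + lowering) and XLA compilation."""
    t0 = time.perf_counter()
    lowered = jax.jit(fn).lower(*args)  # traces fn once with abstract args
    t1 = time.perf_counter()
    compiled = lowered.compile()        # XLA compilation only
    t2 = time.perf_counter()
    return compiled, t1 - t0, t2 - t1


y0 = [2, 3] * 200  # larger hypothetical ragged input
z0 = h(y0)
x0 = jnp.ones(z0.shape, dtype=jnp.float32)

f_compiled, f_trace, f_xla = time_jit(lambda x: f(x, y0), x0)
g_compiled, g_trace, g_xla = time_jit(lambda x: g(x, z0), x0)
print(f"f: trace+lower {f_trace:.4f}s, compile {f_xla:.4f}s")
print(f"g: trace+lower {g_trace:.4f}s, compile {g_xla:.4f}s")
```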
Some additional notes:
- Writing `h` in a JAX-friendly way would be awkward (ragged input formatting, tons of looping/conditionals, output shape dependent on input values, etc.) and ultimately more trouble than it's worth, since the function is extremely cheap to execute and I never need to differentiate it (the input data is integer-based).
Thoughts?
Edit, added for clarity: I know there may be ways around this if, e.g., `f` is a top-level function. In that case it isn't a big deal to have the user call `h` first to "pre-compile" the JAX-friendly inputs to `g`, then freely apply whatever JAX transformations they want to `lambda x: g(x, z0)`. However, I'm imagining cases with many functions that have the same structure as `f` and that we want to chain together: each has some JAX-unfriendly inputs/computations, but those inputs are always treated as constants by the JAX part of the computation. In principle one could always pull these precomputations out to set up the JAX part, but that seems difficult with a non-trivial collection of such functions calling each other.
Is there some way to control how `f` gets traced, so that during tracing it simply evaluates `z = h(y)` (instead of tracing `h`) and then continues tracing `g(x, z)`?