This is a basic example:

```python
import jax

@jax.jit
def block(arg1, arg2):
    for x1 in range(cons1):
        for x2 in range(cons2):
            for x3 in range(cons3):
                ...  # --do something--
    return result
```
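To see why compile time grows with the loop bounds, you can count the operations JAX records while tracing a Python loop. The sketch below is illustrative only: `python_loop` and `num_traced_ops` are hypothetical helpers, and `x * 2.0 + i` is a placeholder for `--do something--`. It uses `jax.make_jaxpr`, which traces a function and returns the recorded program without compiling it.

```python
import jax
import jax.numpy as jnp

def python_loop(x, n):
    # Ordinary Python loop: JAX records every iteration separately,
    # so the traced program contains n copies of the loop body.
    for i in range(n):
        x = x * 2.0 + i
    return x

def num_traced_ops(n):
    # make_jaxpr traces (but does not compile) the function and returns
    # the recorded program; its equation count is roughly what XLA must compile.
    closed = jax.make_jaxpr(lambda x: python_loop(x, n))(jnp.float32(1.0))
    return len(closed.jaxpr.eqns)

for n in (1, 10, 100):
    print(n, num_traced_ops(n))  # the count grows linearly with n
```

With three nested loops the body is copied `cons1 * cons2 * cons3` times, which is why compilation slows down so sharply as the bounds increase.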
When `cons1`, `cons2` and `cons3` are small, compilation takes around a minute. With larger values it takes much longer, tens of minutes, and I need even larger values. What can be done?

From what I have read, the loops are the cause: Python `for` loops inside a jitted function are unrolled during tracing, so the compiled program grows with the total number of iterations. Are there any workarounds? There is `jax.lax.fori_loop`, but I don't understand how to use it. There is also the `jax.experimental.loops` module, but again I haven't been able to understand it.
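A minimal sketch of the triple loop rewritten with nested `jax.lax.fori_loop` calls follows. It assumes the loop bounds are fixed Python integers (the hypothetical `CONS1`, `CONS2`, `CONS3`) and that `--do something--` accumulates into a single array; `body` is a made-up placeholder for that work. `lax.fori_loop(lower, upper, body_fun, init_val)` calls `body_fun(i, val)` for each `i` in `[lower, upper)` and threads `val` through, and each loop body is traced only once, so the traced program stays the same size no matter how large the bounds are.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for cons1, cons2, cons3 in the question.
CONS1, CONS2, CONS3 = 4, 5, 6

def body(x1, x2, x3, arg1, arg2):
    # Hypothetical per-iteration work standing in for "--do something--".
    return arg1 * x1 + arg2 * x2 * x3

@jax.jit
def block_fori(arg1, arg2):
    # fori_loop(lower, upper, body_fun, init_val) runs body_fun(i, val)
    # for i in [lower, upper); each body is traced once, not unrolled.
    def loop3(x1, x2):
        return lambda x3, acc: acc + body(x1, x2, x3, arg1, arg2)

    def loop2(x1):
        return lambda x2, acc: lax.fori_loop(0, CONS3, loop3(x1, x2), acc)

    def loop1(x1, acc):
        return lax.fori_loop(0, CONS2, loop2(x1), acc)

    return lax.fori_loop(0, CONS1, loop1, jnp.zeros_like(arg1))

a = jnp.arange(3, dtype=jnp.float32)
b = jnp.ones(3, dtype=jnp.float32)
print(block_fori(a, b))  # equals 180 * a + 600 * b for these bounds
```

The carry (`acc` here) must keep the same shape and dtype on every iteration; if the real loop body updates several values, they can be carried together as a tuple.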
I am very new to all this, so any help is appreciated. Examples of how to use JAX's loop constructs would be especially helpful.
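`jax.lax.scan` is another structured loop in JAX, useful when each iteration should also produce an output that gets stacked, such as a history of intermediate results. The sketch below uses a hypothetical `running_sum` function purely to show the call pattern `lax.scan(f, init, xs)`, where `f(carry, x)` returns `(new_carry, y)` and `scan` returns the final carry plus all `y` values stacked along a new leading axis.

```python
import jax
import jax.numpy as jnp
from jax import lax

def running_sum(xs):
    # lax.scan calls f(carry, x) once per slice along xs's leading axis.
    # f returns (new_carry, y); scan returns (final_carry, stacked ys).
    def f(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # emit the running total at every step

    final, history = lax.scan(f, jnp.float32(0.0), xs)
    return final, history

final, history = jax.jit(running_sum)(jnp.arange(1.0, 5.0))
print(final)    # final == 10
print(history)  # history == [1, 3, 6, 10]
```

As with `fori_loop`, the body `f` is traced once regardless of how many elements `xs` has.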
Also, what is an acceptable compile time? Is it normal for compilation to take minutes? In one of my examples, compilation takes 262 seconds, while subsequent runs take only about 0.1 to 0.2 seconds, so any gain in runtime is overshadowed by the compile time.
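One way to separate compile time from run time is to time the first call of a jitted function (trace, compile and run) against a second call with the same input shapes and dtypes, which reuses the cached executable. The sketch below does this for a placeholder workload written once as an unrolled Python loop and once with `lax.fori_loop`; `N_STEPS` and `step` are hypothetical, and actual timings depend on hardware and JAX version. `block_until_ready()` is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 300  # hypothetical loop length

def step(i, x):
    # Hypothetical per-iteration work.
    return jnp.sin(x) + i

def unrolled(x):
    for i in range(N_STEPS):  # Python loop: traced N_STEPS times
        x = step(i, x)
    return x

def rolled(x):
    return lax.fori_loop(0, N_STEPS, step, x)  # body traced once

def first_and_second_call(fn, x):
    jitted = jax.jit(fn)
    t0 = time.perf_counter()
    jitted(x).block_until_ready()  # first call: trace + compile + run
    t1 = time.perf_counter()
    jitted(x).block_until_ready()  # second call: cached executable, run only
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1

x = jnp.ones(1000, dtype=jnp.float32)
for name, fn in (("python loop", unrolled), ("fori_loop", rolled)):
    first, second = first_and_second_call(fn, x)
    print(f"{name}: first call {first:.3f} s, second call {second:.4f} s")
```

Because `jax.jit` caches compiled executables per input shape and dtype, the compile cost is paid once per distinct input signature, not on every call.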