
Is it possible to avoid recompiling a JIT function when the structure of its input remains essentially unchanged, aside from one axis having a varying number of elements?

import jax

@jax.jit
def f(x):
    print('recompiling')
    return (x + 10) * 100

a = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling
b = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready())
c = f(jax.numpy.arange(450000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling. It would be nice if it weren't

Requirements: pip install jax jaxlib

mutableVoid
    According to https://github.com/google/jax/issues/803 this doesn’t seem possible at the moment. The XLA compiler requires known shapes. – jkr Nov 26 '21 at 15:19
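One way to see that the compiled program is tied to one input shape is JAX's ahead-of-time path, `jax.jit(f).lower(...).compile()`. The minimal sketch below assumes a reasonably recent JAX release. It compiles `f` for a `(75, 2, 2)` float32 input and then calls the resulting executable with a different leading axis. The specific shapes are illustrative, and the exact exception type raised on a mismatch is an assumption based on current JAX behavior.

```python
import jax
import jax.numpy as jnp

def f(x):
    return (x + 10) * 100

# Lower and compile ahead of time for one concrete input shape and dtype.
compiled = jax.jit(f).lower(jnp.zeros((75, 2, 2), jnp.float32)).compile()

out = compiled(jnp.ones((75, 2, 2), jnp.float32))  # matching shape: runs
print(out.shape)  # (75, 2, 2)

rejected = False
try:
    compiled(jnp.ones((112, 2, 2), jnp.float32))  # leading axis differs
except TypeError:
    # The compiled executable only accepts the shape it was built for.
    rejected = True
print(rejected)  # True
```

Calling the plain `jax.jit` wrapper instead of `compiled` would not raise here. It would silently trace and compile a second program for the new shape.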

1 Answer


No, there is no way to avoid recompilation when you call a function with arrays of a different shape. Fundamentally, JAX compiles functions for statically shaped inputs and outputs, and calling a JIT-compiled function with an array of a new shape will always trigger recompilation.
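The caching behavior can be observed with a Python side effect. The body of a jitted function only runs in Python when JAX traces it, and tracing happens once per new input shape/dtype signature. In the sketch below, the counter `trace_count` is a hypothetical helper added for illustration, and the array sizes are scaled down from the question.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces (and then compiles) f

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return (x + 10) * 100

f(jnp.zeros((75, 2, 2)))   # new shape: trace + compile
f(jnp.ones((75, 2, 2)))    # same shape and dtype: cache hit, no retrace
f(jnp.zeros((112, 2, 2)))  # leading axis changed: trace + compile again
print(trace_count)  # 2
```

This is the same effect as the `print('recompiling')` in the question, but counted so that it can be checked.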

There is some ongoing work on relaxing this requirement (search "dynamic shapes" in JAX's GitHub repository), but no such APIs are available at the moment.

jakevdp