import jax
import numpy as np
import jax.numpy as jnp

a = []
a_jax = []
for i in range(10000):
    a.append(np.random.randint(1, 5, (5,)))
    a_jax.append(jnp.array(a[i]))

# a_jax = jnp.array(a_jax)

@jax.jit
def calc_add_with_jit(a, b):
    return a + b

def calc_add_without_jit(a, b):
    return a + b

def main_function_with_jit():
    for i in range(99):
        calc_add_with_jit(a_jax[i], a_jax[i+1])

def main_function_without_jit():
    for i in range(99):
        calc_add_without_jit(a[i], a[i+1])

%time calc_add_with_jit(a_jax[1], a_jax[2])
%time main_function_with_jit()
%time main_function_without_jit()

The first %time reports a wall time of 3.33 ms, the second reports 5.58 ms, and the third reports 156 microseconds.

Can anyone explain why this is happening? Why is the JAX JIT version slower than the regular code? I am referring to the second and third %time results.

    If all your time is spent in C code, how is JIT supposed to do any good? The point is to come up with a faster version of Python code, but when you just call numpy functions that are precompiled there's nothing for the JIT bits to change. The `for i in range(99)` loop, maybe, but that's so minimal that it's unsurprising for the overhead costs not to be paid back. – Charles Duffy Sep 11 '22 at 18:20

1 Answer

This question is pretty well answered in the JAX documentation; see FAQ: Is JAX Faster Than NumPy? In particular, quoting from the summary:

if you’re doing microbenchmarks of individual array operations on CPU, you can generally expect NumPy to outperform JAX due to its lower per-operation dispatch overhead. If you’re running your code on GPU or TPU, or are benchmarking more complicated JIT-compiled sequences of operations on CPU, you can generally expect JAX to outperform NumPy.

You are benchmarking sequences of individually-dispatched single operations on CPU, which is precisely the regime that NumPy is designed and optimized for, and so you can expect that NumPy will be faster.
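As a minimal sketch of how the comparison could be restructured (this is not code from the question, and the names `pairwise_sums_numpy`, `pairwise_sums_jax`, `x_np`, and `x_jax` are made up for illustration): stack the rows into one array so each library performs a single dispatched operation instead of 99 separate ones, run the jitted function once before timing so compilation is excluded, and call `block_until_ready()` so JAX's asynchronous dispatch does not hide the real run time. Which version wins will still depend on hardware and array size.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


def pairwise_sums_numpy(x):
    # Adds each row to the next one in a single vectorized NumPy call.
    return x[:-1] + x[1:]


@jax.jit
def pairwise_sums_jax(x):
    # Same computation, traced and compiled once by jax.jit.
    return x[:-1] + x[1:]


# Illustrative data with the same shape and value range as the question.
rng = np.random.default_rng(0)
x_np = rng.integers(1, 5, size=(10000, 5))
x_jax = jnp.asarray(x_np)

# Warm-up call: the first invocation includes tracing and compilation.
pairwise_sums_jax(x_jax).block_until_ready()

# Time 100 repetitions of each version; block_until_ready() waits for
# the asynchronously dispatched JAX computation to finish.
t_np = timeit.timeit(lambda: pairwise_sums_numpy(x_np), number=100)
t_jax = timeit.timeit(
    lambda: pairwise_sums_jax(x_jax).block_until_ready(), number=100
)

print(f"NumPy: {t_np / 100 * 1e6:.1f} us per call")
print(f"JAX:   {t_jax / 100 * 1e6:.1f} us per call")
```

In a notebook, the same pattern applies with `%timeit` in place of `timeit.timeit`, as long as the warm-up call comes first and the timed JAX expression ends in `.block_until_ready()`.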

jakevdp