
I am indexing vectors using JAX, but I have noticed a considerable slowdown compared to numpy when simply indexing arrays. For example, consider making a basic array in JAX numpy and in ordinary numpy:

import jax.numpy as jnp
import numpy as onp 
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)

Then, timing a simple slice between two integers, JAX (on GPU) gives:

%timeit jax_array[435:852]

1000 loops, best of 5: 1.38 ms per loop

And for numpy this gives a time of:

%timeit numpy_array[435:852]

1000000 loops, best of 5: 271 ns per loop

So numpy is about 5000 times faster than JAX. When JAX runs on a CPU instead:

%timeit jax_array[435:852]

1000 loops, best of 5: 577 µs per loop

That is faster, but still around 2000 times slower than numpy. I am using Google Colab notebooks for this, so there should not be a problem with the installation/CUDA.

Am I missing something? I realise that indexing is different for JAX and numpy, as described in the JAX 'sharp edges' documentation, but I cannot find any way to perform a slice such as

new_array = jax_array[435:852]

without a considerable slowdown. I cannot avoid indexing the arrays as it is necessary in my program.

Adrian Mole
  • are you using `jit` here? – joel Aug 27 '21 at 10:16
  • @joel I am in my main program, but in this example I have provided, I am not using any extra code. So no `jit` is being used. I think `jit` is for functions anyway, and this is more to do with indexing with JAX numpy, but correct me if I'm wrong. – Danny Williams Aug 27 '21 at 10:41
  • indexing is a function call - `jax_array[435:852]` becomes `jnp.ndarray.__getitem__(jax_array, slice(435, 852))` unless `jax` does strange things under the hood. I can imagine `jit` could affect performance here but I don't actually know – joel Aug 27 '21 at 12:11

1 Answer


The short answer: to speed things up in JAX, use jit.

The long answer:

You should generally expect single operations using JAX in op-by-op mode to be slower than similar operations in numpy. This is because JAX execution has some amount of fixed per-python-function-call overhead involved in pushing compilations down to XLA.
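To illustrate this (a rough sketch only; exact timings are assumptions and will vary by machine and backend), you can observe that the cost of a small op-by-op operation is dominated by dispatch rather than by the size of the data:

```python
import timeit
import jax.numpy as jnp

# Sketch: the same slice on a small and a large array takes roughly the
# same wall-clock time, because per-call dispatch overhead dominates
# the cost of actually moving ten elements.
small = jnp.ones((1_000,))
large = jnp.ones((1_000_000,))

t_small = timeit.timeit(lambda: small[10:20].block_until_ready(), number=100)
t_large = timeit.timeit(lambda: large[10:20].block_until_ready(), number=100)
print(f"small: {t_small / 100 * 1e6:.1f} µs, large: {t_large / 100 * 1e6:.1f} µs")
```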

Even seemingly simple operations like indexing are implemented in terms of multiple XLA operations, which (outside JIT) will each add their own call overhead. You can see this sequence using the make_jaxpr transform to inspect how the function is expressed in terms of primitive operations:

from jax import make_jaxpr
f = lambda x: x[435:852]
make_jaxpr(f)(jax_array)
# { lambda  ; a.
#   let b = broadcast_in_dim[ broadcast_dimensions=(  )
#                             shape=(1,) ] 435
#       c = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
#                   indices_are_sorted=True
#                   slice_sizes=(417,)
#                   unique_indices=True ] a b
#       d = broadcast_in_dim[ broadcast_dimensions=(0,)
#                             shape=(417,) ] c
#   in (d,) }

(See Understanding Jaxprs for info on how to read this).

Where JAX outperforms numpy is not in single small operations (in which JAX dispatch overhead dominates), but rather in sequences of operations compiled via the jit transform. So, for example, compare the JIT-compiled versus not-JIT-compiled version of the indexing:

%timeit f(jax_array).block_until_ready()
# 1000 loops, best of 5: 612 µs per loop

from jax import jit
f_jit = jit(f)
f_jit(jax_array)  # trigger compilation
%timeit f_jit(jax_array).block_until_ready()
# 100000 loops, best of 5: 4.34 µs per loop

(note that block_until_ready() is required for accurate micro-benchmarks because of JAX's asynchronous dispatch)
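Outside of IPython's `%timeit` magic, the same measurement can be sketched with the standard-library `timeit` module (the printed timing is illustrative, not from the benchmark above):

```python
import timeit
from jax import jit
import jax.numpy as jnp

jax_array = jnp.ones((1000,))
f = lambda x: x[435:852]
f_jit = jit(f)
f_jit(jax_array).block_until_ready()  # warm-up call triggers compilation

# block_until_ready() forces the asynchronous computation to finish,
# so we time the actual work rather than just the dispatch.
t = timeit.timeit(lambda: f_jit(jax_array).block_until_ready(), number=10_000)
print(f"{t / 10_000 * 1e6:.2f} µs per call")
```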

JIT-compiling this code gives a roughly 140x speedup. It's still not as fast as numpy, because each call to a JIT-compiled function still pays a few microseconds of dispatch overhead, while the compilation cost itself is incurred only once. And when you move past microbenchmarks to more complicated sequences of real-world computations, those few microseconds will no longer dominate, and the optimization provided by the XLA compiler can make JAX far faster than the equivalent numpy computation.
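For example, a hypothetical pipeline (my own illustration, not from the question) that slices, transforms, and reduces is fused into a single XLA executable under `jit`, so the dispatch cost is paid once for the whole sequence rather than once per operation:

```python
from jax import jit
import jax.numpy as jnp

# Hypothetical multi-step computation: without jit, the slice, tanh,
# multiply, and sum are each dispatched to XLA separately; under jit,
# XLA compiles and fuses them into one executable.
def pipeline(x):
    y = x[435:852]
    return jnp.sum(jnp.tanh(y) * 2.0)

pipeline_jit = jit(pipeline)

x = jnp.ones((1000,))
pipeline_jit(x).block_until_ready()  # first call triggers compilation
result = pipeline_jit(x)
```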

jakevdp