I am indexing vectors using JAX, but I have noticed a considerable slowdown compared to numpy when simply indexing arrays. For example, consider creating a basic array with jax.numpy and with ordinary numpy:
import jax.numpy as jnp
import numpy as onp
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)
Then, slicing between two integer indices with JAX (on GPU) gives a time of:
%timeit jax_array[435:852]
1000 loops, best of 5: 1.38 ms per loop
And for numpy this gives a time of:
%timeit numpy_array[435:852]
1000000 loops, best of 5: 271 ns per loop
So numpy is roughly 5000 times faster than JAX. When JAX runs on the CPU:
%timeit jax_array[435:852]
1000 loops, best of 5: 577 µs per loop
So it is faster, but still roughly 2000 times slower than numpy. I am using Google Colab notebooks for this, so there should not be a problem with the installation or CUDA.
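A minimal sketch for reproducing the comparison above as a plain script rather than with `%timeit`. It assumes a recent JAX version in which arrays expose `block_until_ready()`. Because JAX dispatches work asynchronously, the sketch waits for each slice to finish so the timing covers the whole operation, not just its dispatch. The helper `time_per_call` is a hypothetical name, and it reports a mean per call rather than `%timeit`'s best-of-5, so the numbers will not match the ones quoted exactly.

```python
import timeit

import jax.numpy as jnp
import numpy as onp

jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)

def time_per_call(fn, number=1000):
    """Hypothetical helper: mean wall-clock seconds per call of fn."""
    fn()  # warm-up call so one-time compilation is not counted
    total = timeit.timeit(fn, number=number)
    return total / number

# JAX dispatches asynchronously, so block until the slice is actually computed.
jax_time = time_per_call(lambda: jax_array[435:852].block_until_ready())
# A basic numpy slice returns a view, so this mostly measures Python overhead.
numpy_time = time_per_call(lambda: numpy_array[435:852])

print(f"JAX:   {jax_time * 1e6:.2f} µs per slice")
print(f"numpy: {numpy_time * 1e6:.3f} µs per slice")
print(f"ratio: {jax_time / numpy_time:.0f}x")
```

Both slices contain 852 - 435 = 417 elements; only the per-call cost differs.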
Am I missing something? I realise that indexing behaves differently in JAX and numpy, as described in the JAX 'Sharp Bits' documentation, but I cannot find any way to perform an assignment such as
new_array = jax_array[435:852]
without a considerable slowdown. I cannot avoid indexing the arrays, as it is necessary in my program.
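One commonly used way to reduce per-operation overhead, sketched here under the assumption that each slice feeds into further computation in the real program, is to perform the indexing inside a `jax.jit`-compiled function so it is compiled together with the surrounding work instead of being dispatched on its own. The function names `slice_and_sum` and `dynamic_window_sum` are hypothetical placeholders for that surrounding work. `jax.lax.dynamic_slice` covers the case where the start index is only known at run time; its slice size must be static. This is a sketch, not a guaranteed fix: whether it helps depends on how the slices are used.

```python
import jax
import jax.numpy as jnp

jax_array = jnp.ones((1000,))

@jax.jit
def slice_and_sum(x):
    # With literal bounds the slice is static: it is traced once and compiled
    # into the same XLA program as the arithmetic that follows.
    window = x[435:852]
    return jnp.sum(window * 2.0)

@jax.jit
def dynamic_window_sum(x, start):
    # lax.dynamic_slice accepts a traced start index but needs a static size,
    # so `start` can change between calls without triggering recompilation.
    # The start index is clamped so the 417-element window stays in bounds.
    window = jax.lax.dynamic_slice(x, (start,), (417,))
    return jnp.sum(window)

slice_and_sum(jax_array).block_until_ready()      # first call compiles
print(float(slice_and_sum(jax_array)))            # 834.0 (417 ones, doubled)
print(float(dynamic_window_sum(jax_array, 435)))  # 417.0
```

Later calls with arrays of the same shape and dtype reuse the compiled program, so the compilation cost is paid once.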