This is slow because there are tradeoffs in the implementation of `split()`, and your function happens to be on the wrong side of the tradeoff.
There are several ways to compute slices in XLA, including XLA:Slice (i.e. `lax.slice`), XLA:DynamicSlice (i.e. `lax.dynamic_slice`), and XLA:Gather (i.e. `lax.gather`).
The main difference between these is whether the start and end indices are static or dynamic. Static indices essentially mean you're specializing your computation for specific index values: this incurs some small compilation overhead on the first call, but subsequent calls can be very fast. Dynamic indices, on the other hand, don't include such specialization, so there is less compilation overhead, but each execution takes slightly longer. You may be able to guess where this is going...
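To make the distinction concrete, here is a rough sketch of how the two styles behave when called eagerly (the comments describe the caching at a high level, not XLA's exact internals):

```python
from jax import lax
import jax.numpy as jnp

x = jnp.arange(100.0)

# Static indices: the start/stop values are baked into the compiled
# program, so each distinct pair of bounds triggers its own compilation.
lax.slice(x, (10,), (20,))  # compiles a program specialized to [10:20]
lax.slice(x, (20,), (30,))  # compiles again, specialized to [20:30]

# Dynamic indices: the start is a runtime operand, so one compiled
# program serves every start value (the slice *size* must stay fixed).
lax.dynamic_slice(x, (jnp.asarray(10),), (10,))  # compiles once for size 10
lax.dynamic_slice(x, (jnp.asarray(20),), (10,))  # reuses that program
```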
`jnp.split` is currently implemented in terms of `lax.slice` (see code), meaning it uses static indices. This means that the first use of `jnp.split` will incur a compilation cost proportional to the number of outputs, but repeated calls will execute very quickly. This seemed like the best approach for common uses of `split`, where a handful of arrays are produced. In your case, you're generating hundreds of arrays, so the compilation cost far dominates over the execution.
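For instance, a call along these lines (an assumption about the usage pattern in question, not your actual code) produces hundreds of outputs, each of which gets its own XLA:Slice compilation on first use:

```python
import jax.numpy as jnp

x = jnp.ones(5000)

# 500 equal pieces: with the current lax.slice-based implementation,
# the first call compiles roughly one program per output slice.
pieces = jnp.split(x, 500)
```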
To illustrate this, here are some timings for three approaches to the same array split, based on `gather`, `slice`, and `dynamic_slice`. You might wish to use one of these directly rather than `jnp.split` if your program benefits from a different implementation:
```python
from timeit import default_timer as timer

import jax
import jax.numpy as jnp
from jax import lax


def f_slice(x, step=10):
    # Static start/stop indices: one compilation per distinct slice.
    return [lax.slice(x, (N,), (N + step,)) for N in range(0, x.shape[0], step)]


def f_dynamic_slice(x, step=10):
    # Dynamic start index, static size: one compilation covers all slices.
    return [lax.dynamic_slice(x, (N,), (step,)) for N in range(0, x.shape[0], step)]


def f_gather(x, step=10):
    # Casting step to a JAX array makes the indices dynamic, so standard
    # indexing lowers to gather rather than a static slice.
    step = jnp.asarray(step)
    return [x[N: N + step] for N in range(0, x.shape[0], step)]


def time(f, x):
    print(f.__name__)
    for k in range(5):
        start = timer()
        segments = jax.block_until_ready(f(x))
        end = timer()
        print(f' call {k}: {end - start:0.2f} s')


x = jnp.ones(5000)
time(f_slice, x)
time(f_dynamic_slice, x)
time(f_gather, x)
```
Here's the output on a Colab CPU runtime:
```
f_slice
 call 0: 7.78 s
 call 1: 0.05 s
 call 2: 0.04 s
 call 3: 0.04 s
 call 4: 0.04 s
f_dynamic_slice
 call 0: 0.15 s
 call 1: 0.12 s
 call 2: 0.14 s
 call 3: 0.13 s
 call 4: 0.16 s
f_gather
 call 0: 0.55 s
 call 1: 0.54 s
 call 2: 0.51 s
 call 3: 0.58 s
 call 4: 0.59 s
```
You can see here that static indices (`lax.slice`) lead to the fastest execution after compilation. However, for generating many slices, `dynamic_slice` and `gather` avoid repeated compilations. It may be that we should re-implement `jnp.split` in terms of `dynamic_slice`, but that wouldn't come without tradeoffs: for example, it would lead to a slowdown in the (possibly more common?) case of few splits, where `lax.slice` would be faster on both initial and subsequent runs. Also, `dynamic_slice` only avoids recompilation if each slice is the same size, so generating many slices of varying sizes would incur a large compilation overhead similar to `lax.slice`.
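To make that last point concrete, here's a sketch of the caching behavior (the sizes here are arbitrary, chosen just for illustration):

```python
from jax import lax
import jax.numpy as jnp

x = jnp.ones(5000)

# Equal sizes: a single compiled program covers all 500 calls.
equal = [lax.dynamic_slice(x, (i,), (10,)) for i in range(0, 5000, 10)]

# Varying sizes: each distinct size compiles its own program, so this
# degenerates toward the per-slice compilation cost of lax.slice.
varying = [lax.dynamic_slice(x, (0,), (s,)) for s in (5, 10, 15, 20)]
```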
These kinds of tradeoffs are actively discussed in JAX development channels; a recent example very similar to this can be found in PR #12219. If you wish to weigh in on this particular issue, I'd invite you to file a new jax issue on the topic.
A final note: if you're truly just interested in generating equal-length sequential slices of an array, you would be much better off just calling `reshape`:

```python
out = x.reshape(len(x) // 10, 10)
```
The result is now a 2D array where each row corresponds to a slice from the functions above, and this will far outperform anything that generates a list of array slices.
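For example, row `k` of the reshaped array holds the same values as the `k`-th slice returned by the functions above:

```python
import jax.numpy as jnp

x = jnp.ones(5000)
out = x.reshape(len(x) // 10, 10)

# Row 3 of `out` matches the slice x[30:40] from the earlier approaches.
assert jnp.array_equal(out[3], x[30:40])
```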