
I have a function that works on batches of an array, defined like this:

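import jax
import jax.numpy as jnp

def batched_fn(X):
  @jax.jit
  def apply(Xb):
    # uses both the full array X (captured by the closure) and the batch Xb
    Xb_out = ...
    return Xb_out
  return apply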

The apply function uses both X and Xb to calculate Xb_out, and it can be called on each batch like this:

n = X.shape[0]
batches = []
batch_apply = batched_fn(X)
for i in range(0, n, batch_size):
  s = slice(i, min(i+batch_size, n))
  Xb = batch_apply(X[s])
  batches.append(Xb)
X_out = jnp.concatenate(batches, axis=0)
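For reference, here is a runnable version of the loop above. The body of apply is elided in the question, so the computation below (centring each row with the column means of the full X) is purely a hypothetical stand-in; apply_in_batches is also a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def batched_fn(X):
    # Hypothetical body: the real computation is elided in the question.
    # Each batch is centred using the column means of the *full* X.
    mean = X.mean(axis=0)

    @jax.jit
    def apply(Xb):
        return Xb - mean

    return apply

def apply_in_batches(X, batch_size):
    # Python loop over contiguous slices; the last batch may be shorter.
    n = X.shape[0]
    batch_apply = batched_fn(X)
    batches = []
    for i in range(0, n, batch_size):
        s = slice(i, min(i + batch_size, n))
        batches.append(batch_apply(X[s]))
    return jnp.concatenate(batches, axis=0)

X = jnp.arange(20.0).reshape(10, 2)
X_out = apply_in_batches(X, batch_size=4)
```

Note that because apply is jitted, a shorter final batch has a different shape and triggers one extra compilation.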

I tried to rewrite the above using jax.vmap like this:

func = batched_fn(X)
X_out = jax.vmap(func)(X)

This seems to call func with only one row and not a batch of rows!
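One way to confirm what vmap passes in is to record the argument's shape with a Python side effect, which runs once while vmap traces the function. A minimal sketch, with a hypothetical func that just doubles its input:

```python
import jax
import jax.numpy as jnp

seen_shapes = []

def func(Xb):
    # Python side effect: runs at trace time and records the shape
    # that vmap actually hands to the function.
    seen_shapes.append(Xb.shape)
    return Xb * 2.0

X = jnp.ones((6, 3))
X_out = jax.vmap(func)(X)
print(seen_shapes)   # [(3,)]  -> func is traced with a single row
print(X_out.shape)   # (6, 3)  -> results for all rows are stacked
```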

What is the proper way to batch a jax array?

bachr
  • Control of batching is achieved through the shape of the array: `jax.vmap(func, in_axes=(0,))(X.reshape(batch_count, batch_size, *X.shape[1:]))`. However, the limiting factor is how much the primitive vector operations on a single core can handle at once. If you want parallelism across multiple devices, `jax.pmap` is what you want. – DavidJ Dec 15 '22 at 10:25
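The reshape suggested in the comment above can be sketched as follows. It assumes the number of rows is an exact multiple of batch_size, and apply is a hypothetical per-batch function (here it subtracts each batch's own mean, so it genuinely needs the whole batch):

```python
import jax
import jax.numpy as jnp

def apply(Xb):
    # Hypothetical per-batch function operating on a (batch_size, d) block.
    return Xb - Xb.mean(axis=0)

X = jnp.arange(24.0).reshape(12, 2)
batch_size = 4
batch_count = X.shape[0] // batch_size   # assumes n % batch_size == 0

# Add a leading batch axis: (n, d) -> (batch_count, batch_size, d).
Xb = X.reshape(batch_count, batch_size, *X.shape[1:])
# vmap over axis 0, so each call of apply sees one whole (batch_size, d) batch.
out = jax.vmap(apply, in_axes=0)(Xb)
X_out = out.reshape(X.shape)
```

This changes what apply sees (a batch instead of a row), but as the answer below points out, it does not by itself make the batches run in parallel.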

1 Answer


It sounds like this is working as expected: vmap is not a batching transform in the way you're thinking of it, but rather a vectorizing transform whose result is equivalent to calling the function on one row at a time and stacking the outputs.
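That equivalence can be checked directly against an explicit Python loop. A small sketch with a hypothetical per-row function row_fn:

```python
import jax
import jax.numpy as jnp

def row_fn(x):
    # Hypothetical per-row function: x has shape (d,).
    return jnp.dot(x, x) * x

X = jnp.arange(12.0).reshape(4, 3)

# vmap rewrites row_fn into batched primitives over the leading axis...
vmapped = jax.vmap(row_fn)(X)
# ...but the result is the same as looping over rows and stacking.
looped = jnp.stack([row_fn(X[i]) for i in range(X.shape[0])])
```

The difference between the two is how the computation is expressed to the compiler, not a guarantee of parallel execution.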

jakevdp
  • Yes, I realized that's what it is doing, but then how can I batch more than one row at a time and use vmap to process the batches in parallel? – bachr Dec 12 '22 at 09:17
  • I'm not sure how to answer that question: it seems to be based on the premise that vmap leads to parallel execution, which is incorrect. So maybe the best answer is: you can't use vmap to do what you have in mind. – jakevdp Dec 12 '22 at 12:48
  • I guess I can't use this code with vmap as-is. Maybe I should reshape my array to add an extra dimension for the batches and update the code to handle it; then vmap will run in parallel over this batch dimension. – bachr Dec 12 '22 at 14:57
  • There are two aspects to this, and Jake is cutting to the chase on the second. First, to use vmap with batching you will need to reshape the state array so that the number of batches (or, more practically, the batch size along the second axis) is what you want, and vmap over the leading batch axis. Note that you may have to pad the array with a fill value so the reshape works. Second, batching with vmap makes no sense: either it can all be vectorised or it can't. Batching only makes sense with a sequential (jax.lax.map) or parallel (jax.pmap) outer scan/accumulate. – DavidJ Dec 15 '22 at 10:15
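The sequential variant described in the last comment (pad, reshape into batches, then loop over the batches with jax.lax.map while each batch is processed in vectorized form) might look like the sketch below. map_over_batches and apply are hypothetical names, and apply is deliberately row-independent so that the zero padding rows cannot affect the real rows before they are sliced off; a function that mixes rows within a batch would need masking instead.

```python
import jax
import jax.numpy as jnp

def apply(Xb):
    # Hypothetical vectorized per-batch function; rows are independent,
    # so padded rows do not change the real outputs.
    return 2.0 * Xb + 1.0

def map_over_batches(f, X, batch_size):
    n = X.shape[0]
    batch_count = -(-n // batch_size)        # ceiling division
    pad = batch_count * batch_size - n
    # Pad the leading axis with zero rows so the reshape is exact.
    Xp = jnp.pad(X, [(0, pad)] + [(0, 0)] * (X.ndim - 1))
    Xb = Xp.reshape(batch_count, batch_size, *X.shape[1:])
    # lax.map applies f to one batch at a time (a compiled sequential loop);
    # inside each call, f is vectorized over the rows of that batch.
    out = jax.lax.map(f, Xb)
    # Flatten back to rows and drop the padding.
    return out.reshape(batch_count * batch_size, *out.shape[2:])[:n]

X = jnp.arange(14.0).reshape(7, 2)
X_out = map_over_batches(apply, X, batch_size=3)
```

Swapping jax.lax.map for jax.pmap would give the parallel version mentioned in the comment, but that requires batch_count to match the number of available devices.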