Each time `vmap` is applied, it maps over a single axis. For simplicity, say that you have a function that takes two scalars and outputs a scalar:
```python
import jax
import jax.numpy as jnp

def f(x, y):
    assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
    return x + y

print(f(1, 2))
# 3
```
If you want to apply this function to an array of `x` values and a single `y` value, you can do this with `vmap`:
```python
f_mapped_over_x = jax.vmap(f, in_axes=(0, None))
x = jnp.arange(5)
print(f_mapped_over_x(x, 1))
# [1 2 3 4 5]
```
`in_axes=(0, None)` means that the function is mapped along the leading axis of the first argument, `x`, and that there is no mapping of the second argument, `y`.
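One way to check what `in_axes=(0, None)` does is to compare it with an explicit Python loop over `x` that keeps `y` fixed. This is a minimal sketch that redefines the same `f` so it runs on its own; the loop version is only for illustration and is much slower than `vmap` for large inputs.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
    return x + y

x = jnp.arange(5)

# vmap over the leading axis of x; the same y is passed to every call
vmapped = jax.vmap(f, in_axes=(0, None))(x, 1)

# Equivalent explicit loop: call f on each scalar x[i] with the same y
looped = jnp.stack([f(xi, 1) for xi in x])

print(vmapped)                           # [1 2 3 4 5]
print(jnp.array_equal(vmapped, looped))  # True
```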
Likewise, if you want to apply this function to a single `x` value and an array of `y` values, you can specify this via `in_axes`:
```python
f_mapped_over_y = jax.vmap(f, in_axes=(None, 0))
y = jnp.arange(5, 10)
print(f_mapped_over_y(1, y))
# [ 6  7  8  9 10]
```
If you wish to map the function over both arrays at once, you can do this by specifying `in_axes=(0, 0)`, or equivalently `in_axes=0`:
```python
f_mapped_over_x_and_y = jax.vmap(f, in_axes=(0, 0))
print(f_mapped_over_x_and_y(x, y))
# [ 5  7  9 11 13]
```
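A quick sketch showing that the shorthand `in_axes=0` gives the same result as `in_axes=(0, 0)`, and that for this particular `f` (plain addition) mapping over both arguments at once is just elementwise `x + y`. It assumes the same `f`, `x` and `y` as above, redefined so the snippet is self-contained; both arrays must have the same length along the mapped axis (here, 5).

```python
import jax
import jax.numpy as jnp

def f(x, y):
    assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
    return x + y

x = jnp.arange(5)
y = jnp.arange(5, 10)

explicit = jax.vmap(f, in_axes=(0, 0))(x, y)
shorthand = jax.vmap(f, in_axes=0)(x, y)  # a single int applies to every argument

print(explicit)                               # [ 5  7  9 11 13]
print(jnp.array_equal(explicit, shorthand))   # True
print(jnp.array_equal(explicit, x + y))       # True: f is elementwise addition
```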
But suppose you want to map over `x` and over `y` independently, to get a sort of "outer-product" version of the function. You can do this with a nested `vmap`: the outer `vmap` maps over just `x`, and the inner `vmap` maps over just `y`:
```python
f_mapped_over_x_then_y = jax.vmap(
    jax.vmap(f, in_axes=(None, 0)),  # inner: maps over y
    in_axes=(0, None),               # outer: maps over x
)
print(f_mapped_over_x_then_y(x, y))
# [[ 5  6  7  8  9]
#  [ 6  7  8  9 10]
#  [ 7  8  9 10 11]
#  [ 8  9 10 11 12]
#  [ 9 10 11 12 13]]
```
The nesting of `vmap`s is what lets you map over two axes separately: each `vmap` adds one mapped axis, so two nested `vmap`s produce a two-dimensional output.
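As a cross-check, the nested result is the same "outer sum" that broadcasting gives with `x[:, None] + y[None, :]`. The nesting order also determines the output layout, since the outer `vmap` supplies the first output axis. In this sketch `y` is shortened to three elements (an assumption for illustration only) so the two layouts can be told apart by shape; swapping which `vmap` is on the outside gives the transposed array.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
    return x + y

x = jnp.arange(5)     # length 5
y = jnp.arange(5, 8)  # length 3, shortened so the shapes differ

# Outer vmap over x, inner over y: result[i, j] == f(x[i], y[j])
x_then_y = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y)

# Outer vmap over y, inner over x: result[j, i] == f(x[i], y[j])
y_then_x = jax.vmap(jax.vmap(f, in_axes=(0, None)), in_axes=(None, 0))(x, y)

print(x_then_y.shape, y_then_x.shape)                      # (5, 3) (3, 5)
print(jnp.array_equal(x_then_y, x[:, None] + y[None, :]))  # True
print(jnp.array_equal(y_then_x, x_then_y.T))               # True
```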