
I just about understand unnested vmaps, but try as I may, and I have tried my darnedest, nested vmaps continue to elude me. Take the snippet from this text, for example:

[image of the code snippet, not reproduced]

I don't understand what the axes are in this case. Is the nested vmap(kernel, (0, None)) some sort of partial function application? Why is the function mapped twice? Can someone please explain, in other words, what is going on behind the scenes? What does a nested vmap desugar to? All the answers that I have found are variants of the same curt explanation, "mapping over both axes", which I am struggling with.

Olumide

1 Answer


Each time vmap is applied, it maps over a single axis. So say for simplicity that you have a function that takes two scalars and outputs a scalar:

import jax
import jax.numpy as jnp

def f(x, y):
  assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
  return x + y

print(f(1, 2))
# 3

If you want to apply this function to an array of x values and a single y value, you can do this with vmap:

f_mapped_over_x = jax.vmap(f, in_axes=(0, None))

x = jnp.arange(5)
print(f_mapped_over_x(x, 1))
# [1 2 3 4 5]

in_axes=(0, None) means that it is mapped along the leading axis of the first argument, x, and there is no mapping of the second argument, y.
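One way to read in_axes=(0, None) is as a loop over the leading axis of x with y passed through unchanged. The sketch below is only a mental model: vmap vectorizes the computation rather than running a Python loop, and `f_looped_over_x` is a hypothetical helper name introduced here for illustration.

```python
import jax
import jax.numpy as jnp

def f(x, y):
  return x + y

def f_looped_over_x(x, y):
  # Loop over the leading axis of x; y is handed to f as-is each time.
  return jnp.stack([f(x_i, y) for x_i in x])

x = jnp.arange(5)
vmapped = jax.vmap(f, in_axes=(0, None))(x, 1)
looped = f_looped_over_x(x, 1)

print(jnp.array_equal(vmapped, looped))
# True
```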

Likewise, if you want to apply this function to a single x value and an array of y values, you can specify this via in_axes:

f_mapped_over_y = jax.vmap(f, in_axes=(None, 0))

y = jnp.arange(5, 10)
print(f_mapped_over_y(1, y))
# [ 6  7  8  9 10]

If you wish to map the function over both arrays at once, you can do this by specifying in_axes=(0, 0), or equivalently in_axes=0:

f_mapped_over_x_and_y = jax.vmap(f, in_axes=(0, 0))

print(f_mapped_over_x_and_y(x, y))
# [ 5  7  9 11 13]
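As a rough mental model, in_axes=(0, 0) behaves like a `zip` over the two arrays: both are sliced along axis 0 in lockstep, so they need the same length along that axis. The following is a sketch of that reading, not of how vmap is implemented.

```python
import jax
import jax.numpy as jnp

def f(x, y):
  return x + y

x = jnp.arange(5)
y = jnp.arange(5, 10)

# Pair up x[i] with y[i] and apply f to each pair.
zipped = jnp.stack([f(x_i, y_i) for x_i, y_i in zip(x, y)])
vmapped = jax.vmap(f, in_axes=(0, 0))(x, y)

print(jnp.array_equal(vmapped, zipped))
# True
```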

But suppose you want to map first over x, then over y, to get a sort of "outer-product" version of the function. You can do this via a nested vmap, first mapping over just x, then mapping over just y:

f_mapped_over_x_then_y = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))

print(f_mapped_over_x_then_y(x, y))
# [[ 5  6  7  8  9]
#  [ 6  7  8  9 10]
#  [ 7  8  9 10 11]
#  [ 8  9 10 11 12]
#  [ 9 10 11 12 13]]

The nesting of vmaps is what lets you map over two axes separately.
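To make the "desugaring" concrete, the nested vmap above can be read as two nested loops. The outer vmap slices x and hands each element, together with all of y, to the inner function; the inner vmap then slices y. The names `inner` and `outer` are hypothetical helpers for this sketch, and the Python loops are a mental model only, not what vmap actually executes.

```python
import jax
import jax.numpy as jnp

def f(x, y):
  return x + y

x = jnp.arange(5)
y = jnp.arange(5, 10)

def inner(x_i, y):
  # Stands in for vmap(f, in_axes=(None, 0)): x_i is fixed, loop over y.
  return jnp.stack([f(x_i, y_j) for y_j in y])

def outer(x, y):
  # Stands in for vmap(inner, in_axes=(0, None)): loop over x, pass all of y.
  return jnp.stack([inner(x_i, y) for x_i in x])

nested = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y)
looped = outer(x, y)

print(looped.shape)
# (5, 5)
print(jnp.array_equal(nested, looped))
# True
```

In this reading, `result[i, j]` is `f(x[i], y[j])`: the outer vmap contributes the first output axis and the inner vmap the second.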

jakevdp
  • I think part of what was confusing me is that it is difficult (for me) not to see nested vmaps as a case of composition. And my understanding of composition is to apply the nested function first. However, in nested vmaps the outer vmap applies first, so to speak. Applying this thinking to `f_mapped_over_x_then_y`: the outer vmap takes a row of the first argument (i.e. a scalar) and the entirety of the second argument (an array) and passes them to the inner vmap, which takes the entirety of the first argument (the scalar) and a row of the second argument (i.e. a scalar) and passes them to the function `f`. – Olumide Aug 04 '22 at 00:44
  • The possible fly in the ointment of this line of thinking is the question of whether `x` and `y` are row or column vectors. jax.numpy is modelled after numpy, where arrays are row-major by default. How then should I think of the following: `x[1][None]` outputs `DeviceArray([1], dtype=int32)`, `x[None][1]` outputs `DeviceArray([0, 1, 2, 3, 4], dtype=int32)`, and to make things more interesting, `x[None][1][None]` outputs `DeviceArray([[0, 1, 2, 3, 4]], dtype=int32)`. Is this part of the machinery required to make vmaps work? – Olumide Aug 04 '22 at 01:59
  • No, the `None` used in `in_axes` has nothing to do with the `None` (i.e. `np.newaxis`) in numpy indexing, except for the coincidence that `None` is the value used as a special sentinel in both cases. There are no row or column vectors in the `vmap` example: only 1D vectors being passed to transformed functions. – jakevdp Aug 04 '22 at 03:07
  • You are right. It is sufficient to speak of axes and dimensions. The notion of columns and rows is something I picked up from a blog post somewhere in the interweb. – Olumide Aug 05 '22 at 01:57