
I want to use jax.vmap to vectorise this code for performance:

def matrix(dataA, dataB):
    return jnp.array([[func(a, b) for b in dataB] for a in dataA])
matrix(data, data)

I tried this:

def f(x, y):
    return func(x, y)
mapped = jax.vmap(f)
mapped(data, data)

But this only gives the diagonal entries, i.e. f(data[i], data[i]).

Basically, I have a vector such as data = [1, 2, 3, 4, 5], and I want a matrix whose entry (i, j) is f(data[i], data[j]). The resulting matrix therefore has shape (len(data), len(data)).
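To make the symptom concrete, here is a minimal sketch. It assumes a hypothetical scalar func (a * 10 + b, chosen only so that f(a, b) and f(b, a) differ) and shows that a single vmap pairs the two inputs element by element, so it reproduces only the diagonal of the full pairwise matrix.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's func (assumption: any scalar function).
def func(a, b):
    return a * 10 + b

data = jnp.array([1, 2, 3, 4, 5])

# A single vmap maps both arguments along axis 0 together,
# so it evaluates func(data[i], data[i]) for each i.
diag = jax.vmap(func)(data, data)

# The full pairwise matrix built with the original Python double loop.
full = jnp.array([[func(a, b) for b in data] for a in data])

print(diag)           # [11 22 33 44 55]
print(jnp.diag(full)) # [11 22 33 44 55]
```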

akkh



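jax.vmap maps across one set of axes at a time. If you want to map across two independent sets of axes, you can do so by nesting two vmap transformations. The inner vmap holds x fixed (None) and maps over y, producing one row f(x, data[j]) for all j; the outer vmap maps over x with y held fixed and stacks those rows: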

mapped = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
result = mapped(data, data)
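A self-contained sketch that checks the nested vmap against the original double loop. The function f here is a hypothetical stand-in (x * 10 + y, asymmetric on purpose so a transposed result would be caught); substitute your own scalar function.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function standing in for the question's func.
def f(x, y):
    return x * 10 + y

data = jnp.array([1, 2, 3, 4, 5])

# Inner vmap: x fixed (None), y mapped over axis 0 -> row f(x, data[j]) for all j.
row = jax.vmap(f, in_axes=(None, 0))
# Outer vmap: x mapped over axis 0, y fixed (None) -> one row per data[i].
mapped = jax.vmap(row, in_axes=(0, None))

result = mapped(data, data)  # result[i, j] == f(data[i], data[j])

# Reference: the original Python double loop.
expected = jnp.array([[f(a, b) for b in data] for a in data])

print(result.shape)                       # (5, 5)
print(jnp.array_equal(result, expected))  # True
```

The order of nesting matters: swapping the two in_axes tuples (mapping x in the inner vmap and y in the outer one) yields the transpose, result.T.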
jakevdp