Following the answer to this post, the function `f_switch` below, which dynamically switches between multiple functions based on an index array, is defined using `jax.lax.switch`:
import jax
import jax.numpy as jnp
import jax.random as random
from jax import vmap
def g_0(x, y, z, u):
    """Return the sum of the four arguments, accumulated left to right."""
    return ((x + y) + z) + u


def g_1(x, y, z, u):
    """Return the product of the four arguments, accumulated left to right."""
    return ((x * y) * z) * u


def g_2(x, y, z, u):
    """Return the alternating sum x - y + z - u, evaluated left to right."""
    return ((x - y) + z) - u


def g_3(x, y, z, u):
    """Return x divided successively by y, z and u."""
    return ((x / y) / z) / u


# Branch table for jax.lax.switch: index k selects g_k.
g_i = [
    g_0,
    g_1,
    g_2,
    g_3,
]
@jax.jit
def f_switch(i, x, y, z, u):
    """Apply the branch ``g_i[i[n]]`` to ``(x, y, z, u)`` for every entry n of ``i``.

    ``x, y, z, u`` are shared by every branch call; only the branch index
    varies. The result stacks the per-index outputs along a new leading axis
    of length ``len(i)``.
    """
    def pick_branch(index):
        # vmap hands this function one element of ``i`` at a time, which is
        # used as the lax.switch selector into the g_i branch table.
        return jax.lax.switch(index, g_i, x, y, z, u)

    return jax.vmap(pick_branch)(i)
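With input arrays `i_ar` of shape `(len_i,)`, `x_ar`, `y_ar` and `z_ar` of shape `(len_xyz,)`, and `u_ar` of shape `(len_u, len_xyz)`, `out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)` yields `out` of shape `(len_i, len_u, len_xyz)`. The arrays `x_ar`, `y_ar` and `z_ar` broadcast against `u_ar`'s trailing axis, and `vmap` over `i_ar` adds the leading axis: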
# First example: axis sizes for the branch indices, the x/y/z vectors and u.
len_i = 50
len_xyz = 3000
len_u = 1000

# One branch index per leading output row, drawn with minval=0, maxval=len(g_i).
i_ar = random.randint(
    random.PRNGKey(5), shape=(len_i,), minval=0, maxval=len(g_i)
)

# x_ar, y_ar, z_ar use PRNG keys 0, 1, 2 respectively.
x_ar, y_ar, z_ar = (
    random.uniform(random.PRNGKey(seed), shape=(len_xyz,)) for seed in (0, 1, 2)
)

u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))
u_1 = jnp.repeat(u_0, len_xyz)              # each u_0[l] repeated len_xyz times in a row
u_ar = jnp.reshape(u_1, (len_u, len_xyz))   # so row l of u_ar is constant u_0[l]

# x/y/z (len_xyz,) broadcast against u_ar (len_u, len_xyz) inside each branch,
# and vmap over i_ar prepends an axis: out has shape (len_i, len_u, len_xyz).
out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
print('The shape of out is', out.shape)
This worked. **But how can `f_switch` be defined so that `out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)` has shape `(j_len, k_len, l_len)` when the function is applied along the axes `i_ar[j]`, `x_ar[j]`, `y_ar[j, k]`, `z_ar[j, k]` and `u_ar[l]`? In other words, the goal is `out[j, k, l] = g_i[i_ar[j]](x_ar[j], y_ar[j, k], z_ar[j, k], u_ar[l])`.** I am not sure how to do this. Example input arrays:
# Second example: sizes of the j, k and l axes of the desired output.
j_len, k_len, l_len = 82, 20, 100

# Indexed by j.
i_ar = random.randint(random.PRNGKey(0), shape=(j_len,), minval=0, maxval=len(g_i))
x_ar = random.uniform(random.PRNGKey(1), shape=(j_len,))

# Indexed by (j, k); y_ar uses PRNG key 2 and z_ar key 3.
y_ar, z_ar = (
    random.uniform(random.PRNGKey(seed), shape=(j_len, k_len)) for seed in (2, 3)
)

# Indexed by l.
u_ar = random.uniform(random.PRNGKey(4), shape=(l_len,))
I tried to solve this with a nested `vmap`, i.e. to get an output of shape `(j_len, k_len, l_len)` from the given input arrays:
@jax.jit
def f_switch(i, x, y, z, u):
    """Compute ``out[j, k, l] = g_i[i[j]](x[j], y[j, k], z[j, k], u[l])``.

    Parameters
    ----------
    i : integer array, shape (J,)
        Branch index for each j; selects a function from ``g_i``.
    x : array, shape (J,)
    y, z : arrays, shape (J, K)
    u : array, shape (L,)

    Returns
    -------
    array, shape (J, K, L)
    """
    def point(i_s, x_s, y_s, z_s, u_s):
        # Innermost kernel: every argument is a single element here, so
        # lax.switch receives a scalar index as it requires.
        return jax.lax.switch(i_s, g_i, x_s, y_s, z_s, u_s)

    # Each vmap maps over exactly one output axis. The outermost vmap becomes
    # axis 0 and the innermost becomes the last axis, giving (J, K, L).
    # Arguments not indexed by an axis are passed unmapped (None) at that level.
    over_l = jax.vmap(point, in_axes=(None, None, None, None, 0))  # u[l]
    over_kl = jax.vmap(over_l, in_axes=(None, None, 0, 0, None))   # y[., k], z[., k]
    over_jkl = jax.vmap(over_kl, in_axes=(0, 0, 0, 0, None))       # i[j], x[j], y[j], z[j]
    return over_jkl(i, x, y, z, u)
I also tried broadcasting `u_ar` with `u_ar_broadcast = jnp.broadcast_to(u_ar, (j_len, k_len, l_len))` and then passing it to the original `f_switch`. Both attempts failed.