I would like to train the same neural network model in parallel in JAX, instead of using a loop.
We construct the model as:
class Model(flax.linen.Module):
    ........
    return out

# create a list that contains the NN model n times
models = [Model() for i in range(n)]
outputs = []
# loop through all the NNs and apply each one to the input x
for f in models:
    outputs.append(f(x))  # x is the input of the NN
# stack the per-model outputs along a new axis
y = jnp.stack(outputs, axis=1)
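To make the intent concrete, here is a self-contained version of this looped baseline (the Model body, n, and x below are placeholders I made up; each model gets its own parameters):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=4)(x)   # placeholder architecture

n = 3
x = jnp.ones((16, 8))                    # placeholder input batch
key = jax.random.PRNGKey(0)

models = [Model() for _ in range(n)]
keys = jax.random.split(key, n)
# one parameter pytree per model
params_list = [m.init(k, x) for m, k in zip(models, keys)]
# apply each model to the same input, one at a time
outputs = [m.apply(p, x) for m, p in zip(models, params_list)]
# stack along a new "model" axis -> shape (16, n, 4)
y = jnp.stack(outputs, axis=1)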
Instead of using the for loop, how can I parallelize this in JAX so that the n models are optimized in parallel?
I tried using jax.vmap to map all the elements of the models list onto the input batch, but I keep getting an error.
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
class FFN(nn.Module):
    alpha: int = 1

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(features=self.alpha * x.shape[0])(x)
        y = nn.relu(y)
        return jnp.sum(y, axis=-1)
# define a random input tensor
key = random.PRNGKey(0)
batch_size = 16
input_shape = (32,)
x = random.normal(key, (batch_size,) + input_shape)
# create and initialize the model
model = FFN()
params = model.init(key, x)
# apply the model to the input
output = model.apply(params, x)
# list of models
models = [FFN() for i in range(3)]
# loop through the models list
for net in models:
    net.apply(params, x)
# run models in parallel
def model_apply(model, n):
    return model.apply(params, n)
out = jax.vmap(model_apply, in_axes=(0, None))(models, x)
ERROR:
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
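For reference, the direction I was trying to go is something like the sketch below, which vmaps model.apply over a stacked pytree of parameters instead of over the Python list of module objects (batched_params and n are names I made up; I'm not sure this is the idiomatic way to train n copies in parallel):

import jax
import jax.numpy as jnp
import flax.linen as nn

class FFN(nn.Module):
    alpha: int = 1

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(features=self.alpha * x.shape[0])(x)
        y = nn.relu(y)
        return jnp.sum(y, axis=-1)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 32))
n = 3

model = FFN()                 # a single module definition shared by all copies
keys = jax.random.split(key, n)
# initialize n independent parameter pytrees, stacked along a leading axis
batched_params = jax.vmap(model.init, in_axes=(0, None))(keys, x)
# map apply over the stacked parameters while broadcasting the same input
out = jax.vmap(model.apply, in_axes=(0, None))(batched_params, x)
# out has shape (n, 16): one output vector per copy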