I would like to train n copies of the same neural network model in parallel in JAX, instead of looping over them.

We construct the model as:

class Model(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, x):
        ........
        return out


# create a list that contains n copies of the NN model
models = [Model() for i in range(n)]

outputs = []

# loop through all the NNs
for f in models:
    outputs.append(f(x))  # x is the input of the NN

# stack the outputs of the list
y = jnp.stack(outputs, axis=1)

Instead of using the for loop, how can the model be parallelized in JAX so that the n networks are evaluated and optimized in parallel?

I tried using vmap to map the elements of the models list over the input batch, but I keep getting an error.

import jax.numpy as jnp
from jax import random
import flax.linen as nn
import jax


class FFN(nn.Module):
    alpha: int = 1

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(features=self.alpha*x.shape[0])(x)
        y = nn.relu(y)
        return jnp.sum(y, axis=-1)


# Define random input tensor
key = random.PRNGKey(0)
batch_size = 16
input_shape = (32,)
x = random.normal(key, (batch_size,) + input_shape)

# Initialize model
model = FFN()

# Apply model to input
params = model.init(key, x)
output = model.apply(params, x)


# list of models

models = [FFN() for i in range(3)]

# loop through models list (note: every copy reuses the same params here)
for net in models:
    net.apply(params, x)
    


# run models in parallel
def model_apply(model, n):
    return model.apply(params, n)

out = jax.vmap(model_apply, in_axes=(0, None))(models, x)

ERROR:

ValueError: vmap was requested to map its argument along axis 0, which implies that 
its rank should be at least 1, but is only 0 (its shape is ())
There are several issues with this approach: 1) vmap works only on arrays, not lists, so the mapped argument should be a JAX or NumPy array; 2) vmap works only on arrays of numeric values, so you can't pass functions or module objects in this format. What you can do instead: 1) duplicate the data and pass the single model object with in_axes=None, or 2) use pmap if there are multiple GPUs or TPU cores. – prahasanam_boi Aug 03 '23 at 12:02
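
Following up on the comment, here is one minimal sketch of the vmap-over-parameters approach, reusing the FFN, key, and x defined above (the names n, keys, stacked_params, and outputs are illustrative, not part of any API). Since all copies share one architecture, a single module object suffices; only the parameters need a leading model axis.

n = 3
model = FFN()  # one module object; only the params differ per copy

# one RNG key per copy: vmapping init over the keys yields a parameter
# pytree whose leaves all carry a leading axis of size n
keys = random.split(key, n)
stacked_params = jax.vmap(model.init, in_axes=(0, None))(keys, x)

# map apply over the model axis of the params while broadcasting
# the same input x to every copy (in_axes=None for x)
outputs = jax.vmap(model.apply, in_axes=(0, None))(stacked_params, x)

print(outputs.shape)  # (n, batch_size) == (3, 16)

A per-model training step can be vmapped the same way over the stacked parameters (and, if desired, per-model optimizer state), so all n networks are updated in a single call instead of a Python loop.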
