
I have a feed-forward neural network which is basically a composition of N functions. I want to pipeline the training procedure of this network in a multi-device environment: execute some of these functions on one device, forward the result to a second device, execute some more functions there, and so on. So far, I think something like the following would work:

# a list of jit-ed functions, each of which executes one or more network layers
subfunctions = [...]
x = ...  # some provided input
for f in subfunctions:
    x = f(x)  # these get called asynchronously, right?
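A minimal sketch of the loop above with explicit device placement, assuming a hypothetical `layer` (one dense layer plus `tanh`) for every stage and a round-robin stage-to-device assignment. JAX dispatches jitted calls asynchronously: each call returns an `Array` before the computation has finished, so `block_until_ready()` is used at the end to actually wait.

```python
import jax
import jax.numpy as jnp


@jax.jit
def layer(w, x):
    # Hypothetical stage: one dense layer followed by tanh.
    return jnp.tanh(x @ w)


devices = jax.devices()
n_stages = 3
# Hypothetical round-robin stage -> device mapping (one device means all stages share it).
dev = [devices[s % len(devices)] for s in range(n_stages)]
keys = jax.random.split(jax.random.PRNGKey(0), n_stages)
# Each stage's parameters are committed to that stage's device.
weights = [jax.device_put(jax.random.normal(k, (4, 4)), dev[s])
           for s, k in enumerate(keys)]


def forward(x):
    for s in range(n_stages):
        x = jax.device_put(x, dev[s])  # move the activation to stage s's device
        x = layer(weights[s], x)       # returns immediately; work runs asynchronously
    return x


out = forward(jnp.ones((2, 4)))
out.block_until_ready()  # wait for the last stage to finish
```

Because both arguments of `layer` are committed to the same device, the jitted call runs on that device.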

In addition, I need the final device to send a "message" with the backpropagated gradients back to the previous device, which in turn applies the chain rule and sends its own gradients back to the device before it.
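One way to sketch this backward "message passing" is with `jax.vjp`: each stage keeps the VJP closure from its forward pass, and the backward pass walks the stages in reverse, turning the incoming cotangent into a parameter gradient plus a cotangent for the previous stage. The `layer` and `loss_fn` functions and the device mapping are hypothetical placeholders, and everything runs eagerly here (no scheduling or overlap).

```python
import jax
import jax.numpy as jnp


def layer(w, x):
    # Hypothetical stage: one dense layer followed by tanh.
    return jnp.tanh(x @ w)


def loss_fn(y):
    # Hypothetical loss evaluated on the last stage's device.
    return jnp.sum(y ** 2)


devices = jax.devices()
n_stages = 3
dev = [devices[s % len(devices)] for s in range(n_stages)]
keys = jax.random.split(jax.random.PRNGKey(0), n_stages)
weights = [jax.device_put(jax.random.normal(k, (4, 4)), dev[s])
           for s, k in enumerate(keys)]
x0 = jnp.ones((2, 4))

# Forward: each stage keeps its own VJP closure, which holds the residuals
# (saved activations) on that stage's device.
x = x0
vjps = []
for s in range(n_stages):
    x = jax.device_put(x, dev[s])            # "send" the activation to stage s
    x, vjp_fn = jax.vjp(layer, weights[s], x)
    vjps.append(vjp_fn)

loss, loss_vjp = jax.vjp(loss_fn, x)
(g,) = loss_vjp(jnp.ones_like(loss))         # dL/d(output of last stage)

# Backward: stage s turns dL/d(its output) into dL/dw_s and dL/d(its input),
# then "sends" dL/d(its input) back to stage s - 1 (chain rule).
grads = [None] * n_stages
for s in reversed(range(n_stages)):
    g = jax.device_put(g, dev[s])
    grads[s], g = vjps[s](g)
```

The per-stage gradients in `grads` should match what `jax.grad` gives for the whole composition.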

I also need these things to happen concurrently, i.e., call device 1's function again on the next input while device 2 is just beginning to process the input it got from device 1 (think of pipelined execution).
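A rough sketch of a GPipe-style fill-and-drain schedule (forward only), assuming the batch is split into microbatches and the same hypothetical `layer` stage as above. At clock tick `t`, stage `s` works on microbatch `t - s`, so once the pipeline is full every stage has work at each tick. Because calls are dispatched asynchronously, stages on different devices may overlap, but whether they actually do depends on the backend and on the transfers between devices.

```python
import jax
import jax.numpy as jnp


@jax.jit
def layer(w, x):
    # Hypothetical stage: one dense layer followed by tanh.
    return jnp.tanh(x @ w)


devices = jax.devices()
n_stages, n_micro = 3, 4
dev = [devices[s % len(devices)] for s in range(n_stages)]
keys = jax.random.split(jax.random.PRNGKey(0), n_stages)
weights = [jax.device_put(jax.random.normal(k, (4, 4)), dev[s])
           for s, k in enumerate(keys)]

batch = jnp.ones((8, 4))
acts = list(jnp.split(batch, n_micro))   # acts[m]: latest activation of microbatch m

for t in range(n_stages + n_micro - 1):  # clock ticks
    for s in range(n_stages):
        m = t - s                        # microbatch that stage s handles at tick t
        if 0 <= m < n_micro:
            x = jax.device_put(acts[m], dev[s])
            acts[m] = layer(weights[s], x)  # async dispatch; does not block

out = jnp.concatenate([jax.device_get(a) for a in acts])
```

Stage `s` at tick `t` consumes the output that stage `s - 1` produced for the same microbatch at tick `t - 1`, so the dependencies are respected regardless of the inner loop order.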

Is there native support in JAX for such operations, or should I be looking into something like mpi4jax? Would that even work for me if I'm managing, say, GPU devices rather than CPU processes?
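For reference, JAX does have built-in collectives that move data between devices inside a compiled program, e.g. `jax.lax.ppermute` under `jax.pmap`. A minimal sketch, assuming one block of data per local device and a hypothetical axis name `"stages"`, that sends each device's block to the next device (the "forward the result to the next stage" step). This is only a building block, not a pipeline scheduler.

```python
from functools import partial

import jax
import jax.numpy as jnp

n = jax.local_device_count()


@partial(jax.pmap, axis_name="stages")
def shift_to_next(x):
    # Device i sends its block to device (i + 1) % n.
    perm = [(i, (i + 1) % n) for i in range(n)]
    return jax.lax.ppermute(x, axis_name="stages", perm=perm)


x = jnp.arange(n * 3.0).reshape(n, 3)  # leading axis: one row per device
y = shift_to_next(x)                   # y[(i + 1) % n] == x[i]
```

On a single device the permutation is `[(0, 0)]`, so `y` equals `x`.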
