I am working on a very wide and shallow computation graph, with a relatively small number of shared parameters, on a single machine. I would like to make the graph wider, but I am running out of memory. My understanding is that Distributed TensorFlow makes it possible to split the graph across workers using the tf.device context manager. However, it is not clear to me how to handle the loss, which can only be computed by running the entire graph, or the training operation.
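To make the setup concrete, here is a minimal pure-Python sketch (no TensorFlow involved) of the kind of model I mean. It assumes the wide part is a set of independent branches whose outputs are summed before a squared-error loss, and it ignores the shared parameters for simplicity. The shards stand in for pieces of the graph that would be pinned to different workers with tf.device, and all function names are hypothetical. The sketch shows why splitting seems feasible in principle: the loss needs only one scalar from each shard, and each shard's gradient needs only the scalar dL/d(pred) sent back to it.

```python
import random

# Hypothetical toy model: a "wide" layer of independent branches, where
# branch i computes w_i * x_i and all branches are summed into one prediction.
# Each shard of branches stands in for a subgraph placed on its own worker.

def make_shards(values, num_shards):
    """Split a list into contiguous shards, one per (simulated) worker."""
    size = (len(values) + num_shards - 1) // num_shards
    return [values[i:i + size] for i in range(0, len(values), size)]

def shard_forward(shard_w, shard_x):
    """Work done on one worker: reduce its branches to a single partial sum."""
    return sum(w * x for w, x in zip(shard_w, shard_x))

def shard_backward(shard_x, upstream_grad):
    """Gradient of the loss w.r.t. this shard's weights, given dL/d(pred)."""
    return [upstream_grad * x for x in shard_x]

def train_step(shards, xs_shards, target, lr=0.05):
    """One SGD step; returns the updated shards and the loss before the update."""
    # Forward: each worker contributes one scalar, not its whole subgraph.
    partials = [shard_forward(w, x) for w, x in zip(shards, xs_shards)]
    pred = sum(partials)
    loss = (pred - target) ** 2
    # Backward: only the scalar dL/d(pred) has to reach each worker.
    upstream = 2 * (pred - target)
    new_shards = []
    for w, x in zip(shards, xs_shards):
        grads = shard_backward(x, upstream)
        new_shards.append([wi - lr * gi for wi, gi in zip(w, grads)])
    return new_shards, loss

if __name__ == "__main__":
    random.seed(0)
    weights = [random.uniform(-1, 1) for _ in range(12)]
    xs = [random.uniform(-1, 1) for _ in range(12)]
    shards, xs_shards = make_shards(weights, 3), make_shards(xs, 3)
    for step in range(20):
        shards, loss = train_step(shards, xs_shards, target=1.0)
        print(step, loss)
```

In this toy case the sharded update is identical to the unsharded one; the real question is how to get the same effect with the actual TensorFlow loss and training op.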
What would be the right strategy to train the parameters for this kind of model?