This page appears pretty often on Google, so I thought I would try to improve on Yaroslav's answer by providing what I hope is a clearer answer for those just getting into distributed TensorFlow.
import tensorflow as tf
import threading
def main(job_name, task):
    """Run one task of a small demo cluster (2 parameter servers, 3 workers).

    Each call starts a ``tf.train.Server`` for its own ``(job_name, task)``
    slot. Every parameter-server task hosts a FIFO queue with
    ``shared_name='done_queue<task>'``. Every worker enqueues its own task
    index onto the queue of every parameter server. Each parameter server
    dequeues one entry per worker and then returns, so the PS tasks shut down
    only after all workers have reported "done".

    Args:
        job_name: ``'ps'`` or ``'worker'``. Any other value falls through both
            branches after the server is created.
        task: index of this task within ``job_name``.
    """
    # Each call gets its own graph: the launcher runs every task as a thread
    # in one process, and building ops concurrently into the shared global
    # default graph is not thread-safe in TF 1.x.
    with tf.Graph().as_default():
        cluster = tf.train.ClusterSpec({
            'ps': ['localhost:22222', 'localhost:22223'],
            # host:port with no whitespace ('localhost: 22224' is not a
            # valid address).
            'worker': ['localhost:22224', 'localhost:22225', 'localhost:22226'],
        })
        server = tf.train.Server(cluster, job_name=job_name, task_index=task)
        num_workers = cluster.num_tasks('worker')

        if job_name == 'ps':
            # The queue is hosted on this PS task; workers reach the same
            # queue by creating a FIFOQueue with the matching shared_name.
            with tf.device('/job:ps/task:%d' % task):
                queue = tf.FIFOQueue(num_workers, tf.int32,
                                     shared_name='done_queue%d' % task)
                # Build the dequeue op once instead of once per loop pass.
                dequeue_op = queue.dequeue()
            # Block until one "done" token per worker has arrived.
            with tf.Session(server.target) as sess:
                for _ in range(num_workers):
                    # The token is the sender's task index, so report it
                    # rather than the loop counter.
                    worker = sess.run(dequeue_op)
                    print('ps:%d received "done" from worker:%d' % (task, worker))
            print('ps:%d quitting' % task)
        elif job_name == 'worker':
            signal_ops = []
            for i in range(cluster.num_tasks('ps')):
                # Handle to the queue hosted on ps task i (same shared_name).
                with tf.device('/job:ps/task:%d' % i):
                    queue = tf.FIFOQueue(num_workers, tf.int32,
                                         shared_name='done_queue%d' % i)
                    signal_ops.append((queue.enqueue(task), queue.size()))
            # Signal "done" to every parameter server.
            with tf.Session(server.target) as sess:
                for i, (enqueue_op, size_op) in enumerate(signal_ops):
                    # Run the enqueue first so the reported size includes
                    # this worker's token. Fetching both in one run() leaves
                    # their relative order unspecified.
                    sess.run(enqueue_op)
                    size = sess.run(size_op)
                    print('Worker:%d sending "done" to ps:%d [elements=%d]' % (task, i, size))
if __name__ == '__main__':
    # Simulate the whole cluster inside one process: one thread per
    # (job_name, task) pair, started together and then joined.
    roles = [('ps', 0), ('ps', 1), ('worker', 0), ('worker', 1), ('worker', 2)]
    threads = [threading.Thread(target=main, args=role) for role in roles]
    for thread in threads:
        thread.start()
    # Return only after every simulated task has finished.
    for thread in threads:
        thread.join()
It's straightforward to extend the "canonical" distributed TensorFlow example: replace the worker section of the code with this snippet:
# create a worker that does nothing
elif job_name == 'worker':
with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:%d' % task, cluster=cluster)):
global_step = tf.train.get_or_create_global_step()
no_op = tf.no_op()
done_ops = []
# create a shared queue on the worker which is visible on /job:ps/task:%d
for i in range(cluster.num_tasks('ps')):
with tf.device('/job:ps/task:%d' % i):
done_queue = tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32, shared_name='done_queue' + str(i))
done_ops.append(done_queue.enqueue(task))
hooks=[tf.train.StopAtStepHook(last_step=1),
tf.train.FinalOpsHook([done_ops])]
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=(task == 0),
hooks=hooks) as sess:
sess.run([no_op])
Note that the MonitoredTrainingSession version seems to be much slower at connecting all of the workers together.