
By default, TensorFlow distributed training establishes all-to-all connections between workers and parameter servers, even though in asynchronous distributed training, the only necessary communication is between each individual worker and the parameter servers.

How do I limit communication when I'm using tf.contrib.learn.Experiment?

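You can limit those connections by setting device filters on the session configuration that the RunConfig carries. With a filter listing the parameter servers plus this worker's own task, the worker's sessions never try to contact the other workers:
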
# Assumes TensorFlow 1.x, where tf.contrib.learn is still available.
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner

# The easiest way to parse the TF_CONFIG environment variable is to create a
# RunConfig. Unfortunately, it is an immutable object, so we create a
# temporary one and use it only for `task_type` and `task_id`.
tmp = tf.contrib.learn.RunConfig()
task_type, task_id = tmp.task_type, tmp.task_id

# We use a device filter so that this worker's sessions only communicate
# with the parameter servers and with itself; there is no need to talk
# directly to the other workers, and attempting to do so can cause
# reliability problems.
device_filters = [
    '/job:ps', '/job:%s/task:%d' % (task_type, task_id)
]
session_config = tf.ConfigProto(device_filters=device_filters)
run_config = tf.contrib.learn.RunConfig(
    model_dir=args.job_dir,  # e.g. a --job-dir flag parsed by this script
    session_config=session_config)

# Create the experiment_fn:
experiment_fn = ...

# Run the experiment
learn_runner.run(experiment_fn, run_config=run_config)
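
For reference, here is a rough sketch (with made-up hostnames) of the kind of TF_CONFIG value that RunConfig parses to obtain `task_type` and `task_id`:

import json
import os

# Hypothetical cluster: one master, one parameter server, two workers.
# With the device filters above, worker 0 only opens connections to the
# /job:ps tasks and to /job:worker/task:0 (itself), not to worker 1.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'master': ['master-0.example.com:2222'],
        'ps': ['ps-0.example.com:2222'],
        'worker': ['worker-0.example.com:2222',
                   'worker-1.example.com:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})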