The ParameterServerTraining tutorial in the TensorFlow API docs has the following snippet of code in the model.fit section:
import tensorflow as tf

def dataset_fn(input_context):
    global_batch_size = 64
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    x = tf.random.uniform((10, 10))
    y = tf.random.uniform((10,))
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
    dataset = dataset.shard(
        input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)
    return dataset

dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
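If I understand the shard() call correctly, each input pipeline only ever sees its own slice of the elements, something like this toy example (my own illustration with 2 workers, not code from the tutorial):

import tensorflow as tf

# With 2 input pipelines, worker 0 keeps elements 0, 2, 4, ...
# and worker 1 keeps elements 1, 3, 5, ...
full = tf.data.Dataset.range(10)
worker0 = full.shard(num_shards=2, index=0)
worker1 = full.shard(num_shards=2, index=1)
print(list(worker0.as_numpy_iterator()))  # [0, 2, 4, 6, 8]
print(list(worker1.as_numpy_iterator()))  # [1, 3, 5, 7, 9]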
The tutorial also says:
The code in dataset_fn will be invoked on the input device, which is usually the CPU, on each of the worker machines.
Does that mean the dataset must be stored on every worker machine (say the parameter server and the workers are different machines)?
Or is there some mechanism in ParameterServerStrategy that I'm missing, where the parameter server on one machine can send the training data to the workers without the worker machines storing the dataset themselves?
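For example, would every worker's dataset_fn have to read files that are physically present on (or at least mounted by) that worker, roughly like the sketch below? The /data/train-*.tfrecord path is just a placeholder I made up, and it assumes all workers can see the same files.

import tensorflow as tf

def dataset_fn(input_context):
    global_batch_size = 64
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    # Assumes every worker can resolve this path, e.g. a shared filesystem
    # or an object-store path -- the path itself is hypothetical.
    files = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.shard(
        input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)
    return dataset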