Actually, you can implement multi-GPU support inside the model_fn itself, the same way as before.
You can find the full code here. It supports a multi-threaded queue reader together with multiple GPUs, which gives very fast training with the Estimator API.
Code snippet (see the link for the full code):
import tensorflow as tf

from nets import nets_factory  # network factory from the TF-Slim models repository

slim = tf.contrib.slim

# FLAGS, get_learning_rate, get_optimizer and average_gradients are defined
# in the full code linked above.


def model_fn(features, labels, mode, params):
    # Build the network function.
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=params['num_classes'],
        weight_decay=0.00004,
        is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    # Prediction: provide an estimator spec for `ModeKeys.PREDICT`.
    if mode == tf.estimator.ModeKeys.PREDICT:
        logits, end_points = network_fn(features)
        return tf.estimator.EstimatorSpec(mode=mode, predictions={"output": logits})

    # Create global_step and the learning rate.
    global_step = tf.train.get_global_step()
    learning_rate = get_learning_rate("exponential", FLAGS.base_lr,
                                      global_step, decay_steps=10000)

    # Create the optimizer.
    optimizer = get_optimizer(FLAGS.optimizer, learning_rate)

    # Multi-GPU support: the splits must sum to the batch size, because the
    # batch size may not be divisible by the number of GPUs. Remaining
    # samples go to the last GPU, e.g. a batch size of 15 with 2 GPUs
    # gives splits of [7, 8].
    batch_size = tf.shape(features)[0]
    split_size = batch_size // len(params['gpus_list'])
    splits = [split_size] * (len(params['gpus_list']) - 1)
    splits.append(batch_size - split_size * (len(params['gpus_list']) - 1))

    # Split the features and labels across the GPUs.
    features_split = tf.split(features, splits, axis=0)
    labels_split = tf.split(labels, splits, axis=0)

    tower_grads = []
    eval_logits = []

    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(len(params['gpus_list'])):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % ("classification", i)) as scope:
                    # Model and loss for this tower.
                    logits, end_points = network_fn(features_split[i])
                    tf.losses.softmax_cross_entropy(labels_split[i], logits)
                    update_ops = tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, scope)
                    updates_op = tf.group(*update_ops)
                    with tf.control_dependencies([updates_op]):
                        losses = tf.get_collection(tf.GraphKeys.LOSSES, scope)
                        total_loss = tf.add_n(losses, name='total_loss')
                    # Reuse variables for the next tower.
                    tf.get_variable_scope().reuse_variables()
                    # Compute the gradients for this tower.
                    grads = optimizer.compute_gradients(total_loss)
                    tower_grads.append(grads)
                    # Keep the logits for the eval metric ops.
                    eval_logits.append(logits)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = average_gradients(tower_grads)

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = optimizer.apply_gradients(
        grads, global_step=global_step)

    # Track the moving averages of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Group all updates into a single train op.
    train_op = tf.group(apply_gradient_op, variables_averages_op)

    # Create the eval metric ops.
    _predictions = tf.argmax(tf.concat(eval_logits, 0), 1)
    _labels = tf.argmax(labels, 1)
    eval_metric_ops = {
        "acc": slim.metrics.streaming_accuracy(_predictions, _labels)}

    # Provide an estimator spec for `ModeKeys.EVAL` and `ModeKeys.TRAIN` modes.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)
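The average_gradients helper is not shown in the snippet; it lives in the full code linked above. If you don't want to pull in the full code, here is a minimal sketch of such a helper, modeled on the one in TensorFlow's multi-GPU CIFAR-10 tutorial (the exact implementation in the linked code may differ):

def average_gradients(tower_grads):
    # tower_grads is a list over towers; each entry is the list of
    # (gradient, variable) pairs returned by optimizer.compute_gradients().
    # Assumes every tower uses every variable, i.e. no gradient is None.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Average the gradient that every tower computed for this variable.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        # The variables are shared across towers, so take the first tower's.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads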
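To use this model_fn you pass num_classes and gpus_list through params and supply your (multi-threaded) input_fn. A minimal usage sketch, where my_train_input_fn, my_eval_input_fn, the model directory and the parameter values are placeholders rather than anything from the linked code:

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/model',            # placeholder directory
    params={'num_classes': 1000,       # placeholder values
            'gpus_list': [0, 1]})

estimator.train(input_fn=my_train_input_fn, max_steps=100000)
estimator.evaluate(input_fn=my_eval_input_fn)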