For those who don't want to read the whole story:
TL;DR: When using TF Estimator, do we have to scale the learning rate by the factor by which we increase the batch size (I know this is the right thing to do; I am just not sure whether TF handles it internally)? Similarly, do we have to scale the per-example loss by the global batch size (batch_size_per_replica * number of replicas)?
The documentation on TensorFlow distributed learning is confusing. I need clarification on the points below.
It is now well established that if you increase the batch size by a factor of k, you also need to increase the learning rate by a factor of k (see this and this paper). However, the official TensorFlow page on distributed learning makes no clarifying comment about this. They do mention here that the learning rate needs to be adjusted, but do they handle the learning rate scaling themselves? To make matters more complicated, the behavior is different between Keras and `tf.Estimator` (see the next point). Should I increase the LR by a factor of k or not when I am using `tf.Estimator`?
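To make the question concrete, here is a minimal sketch of the manual scaling I have in mind; the numbers, `base_lr`, and the use of `MirroredStrategy` are my own placeholders, not something taken from the docs:

```python
import tensorflow as tf

# Hypothetical numbers, purely to illustrate the question.
base_lr = 0.1               # LR that works on a single replica with batch 64
batch_size_per_replica = 64

strategy = tf.distribute.MirroredStrategy()
num_replicas = strategy.num_replicas_in_sync
global_batch_size = batch_size_per_replica * num_replicas

# Linear scaling rule from the papers: grow the LR with the batch size.
# My question: do I apply this myself, or does tf.Estimator do it for me?
scaled_lr = base_lr * num_replicas
print("global batch size:", global_batch_size, "scaled LR:", scaled_lr)
```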
It is widely accepted that the per-example loss should be scaled by the global batch size (global_batch_size = batch_size_per_replica * number of replicas). TensorFlow mentions this here, but when illustrating how to achieve it with a `tf.Estimator`, they either forget the scaling by `global_batch_size` or it is not required. See here: in the code snippet, the loss is defined as follows.
`loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)`
and `BATCH_SIZE`, to the best of my understanding, is defined above as the per-replica batch size.
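For contrast, here is a small sketch of the two scalings side by side as I understand them; the sizes and the `per_example_loss` tensor are made up by me for illustration:

```python
import tensorflow as tf

# Hypothetical sizes, just for illustration.
batch_size_per_replica = 64
num_replicas = 2
global_batch_size = batch_size_per_replica * num_replicas

# Fake per-example loss for the examples that landed on *this* replica.
per_example_loss = tf.random.uniform([batch_size_per_replica])

# What the Estimator example appears to do: divide by the per-replica batch size.
loss_replica_scaled = tf.reduce_sum(per_example_loss) * (1. / batch_size_per_replica)

# What the distributed-training guide recommends: divide by the global batch
# size, so that summing the per-replica gradients reproduces the
# single-machine average loss.
loss_global_scaled = tf.reduce_sum(per_example_loss) * (1. / global_batch_size)

print(float(loss_replica_scaled), float(loss_global_scaled))
```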
To complicate things further, the scaling is handled automatically if you are using Keras (for reasons I will never understand; it would have been better to keep everything consistent).
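For completeness, this is the Keras path I am comparing against, where, as far as I can tell, I never touch the loss scaling myself; the tiny model and data are made up:

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
    # With Keras, the per-replica loss reduction is (apparently) taken care of;
    # I never divide by the global batch size anywhere.
    model.compile(optimizer="sgd", loss="mse")

# Made-up data just so the snippet runs end to end.
x = np.random.rand(256, 10).astype("float32")
y = np.random.rand(256, 1).astype("float32")
model.fit(x, y, batch_size=128, epochs=1)  # batch_size here is the global batch size
```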