Because of limited GPU memory, I want to update my weights only after every two training steps. Specifically, the network first processes the first batch of inputs and saves the loss. Then it processes the next batch, averages the two losses, and updates the weights once. This is like the average_loss option in Caffe, for example in fcn-berkeley. Also, how should the batch norm update ops be handled in this setup?
2 Answers
0
Please check this thread for correct info on Caffe's average_loss.
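You should be able to compute an averaged loss by subclassing LoggingTensorHook along these lines:

    import numpy as np
    import tensorflow as tf  # TF 1.x API

    class MyLoggingTensorHook(tf.train.LoggingTensorHook):
        # set every_n_iter to 2 if you want to average the last 2 losses
        def __init__(self, tensors, every_n_iter):
            super().__init__(tensors=tensors, every_n_iter=every_n_iter)
            # keep track of previous losses
            self.losses = []

        def before_run(self, run_context):
            # the base class only fetches tensors on trigger steps;
            # fetch on every step so that each loss gets recorded
            # (relies on private attributes of the TF 1.x hook)
            self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
            return tf.train.SessionRunArgs(self._current_tensors)

        def after_run(self, run_context, run_values):
            _ = run_context
            # assuming you have a tag like 'average_loss'
            # as the name of your loss tensor
            for tag in self._tag_order:
                if 'average_loss' in tag:
                    self.losses.append(run_values.results[tag])
            if self._should_trigger:
                self._log_tensors(run_values.results)
            self._iter_count += 1

        def _log_tensors(self, tensor_values):
            original = np.get_printoptions()
            np.set_printoptions(suppress=True)
            # restart the trigger timer, as the base class does
            self._timer.update_last_triggered_step(self._iter_count)
            tf.logging.info("%s = %s" % ('average_loss', np.mean(self.losses)))
            np.set_printoptions(**original)
            # start a fresh window of losses
            self.losses = []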
Then pass it through the hooks argument of an estimator's train method or of a TrainSpec.
You should be able to compute the gradients of your variables normally in every step, but apply them only every N steps by conditioning on the global_step variable that holds your current iteration (you should have created this variable in your graph with something like global_step = tf.train.get_or_create_global_step()). Please see the usage of compute_gradients and apply_gradients for this.
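
To make the "compute every step, apply every N steps" idea concrete, here is a minimal framework-free sketch in plain Python. It is not TensorFlow code: the one-weight model, the names batch_gradient and train, and the values of ACCUM_STEPS and LEARNING_RATE are all assumptions made up for illustration. It only shows the arithmetic that compute_gradients and apply_gradients would carry out on real variables.

```python
ACCUM_STEPS = 2      # hypothetical: number of mini-batches per weight update
LEARNING_RATE = 0.1  # hypothetical value, chosen only for illustration


def batch_gradient(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w for one mini-batch."""
    return sum(2.0 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def train(batches, w=0.0):
    """Compute a gradient on every step, apply the average every ACCUM_STEPS steps."""
    accumulated = 0.0
    for step, (xs, ys) in enumerate(batches, start=1):
        accumulated += batch_gradient(w, xs, ys)  # computed on every step
        if step % ACCUM_STEPS == 0:               # applied only on every N-th step
            w -= LEARNING_RATE * accumulated / ACCUM_STEPS
            accumulated = 0.0                     # reset the buffer for the next window
    return w


# Two mini-batches of two samples each, drawn from y = 2 * x.
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
w_accumulated = train(batches)

# Reference: a single update on the combined batch of four samples.
w_full_batch = 0.0 - LEARNING_RATE * batch_gradient(
    0.0, [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])
```

With equally sized mini-batches, w_accumulated and w_full_batch come out the same, which is why averaging the two gradients behaves like training on a batch twice as large. In a TF 1.x graph the accumulator would presumably be a non-trainable variable per weight, and the batch norm update ops collected under tf.GraphKeys.UPDATE_OPS would likely need to run on every mini-batch rather than only on the apply step, but that part is an assumption and not something tested here.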

kkirtac
0
Easy, just use tf.reduce_mean(input_tensor). In your case, it would be:

    # loss1 and loss2 are scalar tensors, so stack them (tf.concat rejects scalars)
    loss = tf.stack([loss1, loss2], axis=0)
    final_loss = tf.reduce_mean(loss, axis=0)

Tbertin