The code snippet below is the custom training loop from the official TensorFlow tutorial (https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch). Another tutorial also does not average the loss over batch_size, as shown here: https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough
Why is loss_value not averaged over batch_size at the line loss_value = loss_fn(y_batch_train, logits)? Is this a bug? According to another question here, "Loss function works with reduce_mean but not reduce_sum", reduce_mean is indeed needed to average the loss over batch_size.
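To make concrete why the choice between the two reductions matters, here is a minimal sketch in plain Python (no TensorFlow). It assumes a hypothetical one-parameter model, prediction = w * x, with a squared-error loss. Neither the model nor the data comes from the tutorial. It shows only that the gradient of a summed loss is the gradient of the averaged loss multiplied by the batch size.

```python
# Hypothetical toy model (not from the tutorial):
# prediction = w * x, per-sample loss = (w * x - y) ** 2.
def per_sample_grads(w, xs, ys):
    """Gradient of each per-sample squared error with respect to w."""
    return [2.0 * (w * x - y) * x for x, y in zip(xs, ys)]

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]   # made-up inputs, batch size 4
ys = [2.0, 4.0, 6.0, 8.0]   # made-up targets

grads = per_sample_grads(w, xs, ys)
grad_of_sum = sum(grads)                  # gradient of a reduce_sum-style loss
grad_of_mean = sum(grads) / len(grads)    # gradient of a reduce_mean-style loss

print("gradient of summed loss:  ", grad_of_sum)
print("gradient of averaged loss:", grad_of_mean)
print("ratio:", grad_of_sum / grad_of_mean)  # equals the batch size
```

With a summed loss, the effective step size therefore grows with the batch size unless the learning rate is scaled down to compensate.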
loss_fn is defined in the tutorial as shown below. As far as I can tell, it does not average over batch_size.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
From the documentation, keras.losses.SparseCategoricalCrossentropy seems to sum the loss over the batch without averaging, so this would essentially be reduce_sum instead of reduce_mean! The documentation of the reduction argument says:
Type of tf.keras.losses.Reduction to apply to loss. Default value is AUTO. AUTO indicates that the reduction option will be determined by the usage context. For almost all cases this defaults to SUM_OVER_BATCH_SIZE.
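The name SUM_OVER_BATCH_SIZE can be read two ways: "sum across the batch" or "sum divided by the batch size". The sketch below computes both candidates by hand in plain Python for a small made-up minibatch, so that either can be compared with what loss_fn returns. The helper sparse_ce_from_logits and the sample values are hypothetical and only for illustration. The per-sample formula is the standard cross-entropy, -log(softmax(logits)[label]).

```python
import math

def sparse_ce_from_logits(label, logits):
    """Cross-entropy of one sample: -log(softmax(logits)[label])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Made-up minibatch: 4 samples, 3 classes (not taken from the tutorial).
batch_logits = [
    [2.0, 0.5, -1.0],
    [0.1, 0.2, 0.3],
    [-1.0, 3.0, 0.0],
    [0.0, 0.0, 0.0],
]
batch_labels = [0, 2, 1, 1]

per_sample = [
    sparse_ce_from_logits(y, z) for y, z in zip(batch_labels, batch_logits)
]

plain_sum = sum(per_sample)                   # the reduce_sum reading
sum_over_batch = plain_sum / len(per_sample)  # the reduce_mean reading

print("per-sample losses:  ", per_sample)
print("sum:                ", plain_sum)
print("sum / batch size:   ", sum_over_batch)
```

Evaluating float(loss_fn(batch_labels, batch_logits)) on the same inputs (converted to tensors) and comparing it with these two numbers would show which reading the default reduction follows.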
The code is shown below.
epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch
            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)
        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * 64))