I am currently working with torch.nn.CrossEntropyLoss. As far as I know, it is common to compute the loss batch-wise. However, is it possible to compute the loss over multiple batches at once?
More concretely, assume we are given the following data:
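import torch
# Example sizes, chosen only for illustration; feature_dim is assumed to equal the number of classes
no_of_batches, batch_size, feature_dim = 4, 8, 10
features = torch.randn(no_of_batches, batch_size, feature_dim)
targets = torch.randint(low=0, high=10, size=(no_of_batches, batch_size))
loss_function = torch.nn.CrossEntropyLoss()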
Is there a way to compute the loss over all batches in one line? The direct call fails:
loss = loss_function(features, targets) # raises RuntimeError: Expected target size [no_of_batches, feature_dim], got [no_of_batches, batch_size]
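For reference, here is a minimal sketch of what I would expect such a call to look like. It assumes that feature_dim equals the number of classes (10 here) and uses example sizes that I picked only for illustration. As far as I understand, CrossEntropyLoss treats dimension 1 of a 3-D input as the class dimension, which would explain the error message, so the two options below either merge the leading dimensions or move the class dimension to position 1.

```python
import torch

# Example sizes (assumed for illustration); feature_dim is taken to be the number of classes.
no_of_batches, batch_size, feature_dim = 4, 8, 10

features = torch.randn(no_of_batches, batch_size, feature_dim)
targets = torch.randint(low=0, high=feature_dim, size=(no_of_batches, batch_size))
loss_function = torch.nn.CrossEntropyLoss()

# Option 1: merge the two leading dimensions into one large batch of samples.
loss_flat = loss_function(features.reshape(-1, feature_dim), targets.reshape(-1))

# Option 2: move the class dimension to position 1, giving an (N, C, d1) input
# that matches the (N, d1) targets.
loss_permuted = loss_function(features.permute(0, 2, 1), targets)

# Reference: compute the loss batch by batch and average the results.
loss_per_batch = torch.stack(
    [loss_function(f, t) for f, t in zip(features, targets)]
).mean()
```

With the default reduction="mean" and equally sized batches, all three values should agree up to floating-point error, since the mean over all samples equals the mean of the per-batch means.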
Thank you in advance!