
I am currently working with torch.nn.CrossEntropyLoss. As far as I know, it is common to compute the loss one batch at a time. However, is there a way to compute the loss over multiple batches at once?

More concretely, assume we are given the data

import torch

no_of_batches, batch_size, feature_dim = 4, 32, 10  # example sizes; feature_dim is the number of classes

features = torch.randn(no_of_batches, batch_size, feature_dim)
targets = torch.randint(low=0, high=10, size=(no_of_batches, batch_size))

loss_function = torch.nn.CrossEntropyLoss()

Is there a way to compute in one line

loss = loss_function(features, targets) # raises RuntimeError: Expected target size [no_of_batches, feature_dim], got [no_of_batches, batch_size]

?

Thank you in advance!


1 Answer


You can compute multiple cross-entropy losses, but you'll need to do your own reduction. Since CrossEntropyLoss expects the class dimension to be the second dimension of its input tensor, you will also need to permute features first.

loss_function = torch.nn.CrossEntropyLoss(reduction='none')

# permute to (no_of_batches, feature_dim, batch_size) so the class dimension
# comes second; the unreduced loss has shape (no_of_batches, batch_size),
# and mean(dim=1) reduces each batch to a single value
loss = loss_function(features.permute(0, 2, 1), targets).mean(dim=1)

which will result in a loss tensor with no_of_batches entries.
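
As a quick sanity check (a minimal sketch using the example tensors defined in the question), this matches computing the default mean-reduced loss one batch at a time in a loop:

# compute the mean-reduced loss separately for each batch and stack the results
per_batch_loss_fn = torch.nn.CrossEntropyLoss()  # default reduction='mean'
loss_loop = torch.stack([per_batch_loss_fn(features[i], targets[i])
                         for i in range(no_of_batches)])

print(torch.allclose(loss, loss_loop))  # True

And since every batch here has the same size, a further loss.mean() gives the same scalar as running the default mean reduction over all batches at once.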
