
I am using a GPT2 model that outputs logits (before softmax) of shape (batch_size, num_input_ids, vocab_size), and I need to compare them against labels of shape (batch_size, num_input_ids) to calculate BCELoss. How do I calculate it?

logits = output.logits                     # of shape (32, 56, 592)
logits = torch.nn.Softmax(dim=-1)(logits)  # normalize over the vocabulary dimension
labels = labels                            # of shape (32, 56)

torch.nn.BCELoss()(logits, labels)         # fails: shapes (32, 56, 592) vs (32, 56) do not match

but the dimensions do not match, so how do I either contract the logits to the labels' shape or expand the labels to the logits' shape?

MNK
  • Are the labels binary, i.e., 0/1? Then why do the logits have size 592 in the last dimension? If these shapes are correct, why are you using BinaryCrossEntropy loss? – Umang Gupta Jan 26 '23 at 21:23
  • It's hard to say what's going on without understanding what each dimension represents. Is it batch x feature x channels? – Megan Hardy Jan 26 '23 at 23:01

1 Answer


Binary cross-entropy is used when the final classification layer is a sigmoid layer, i.e., for each output dimension, only a true/false output is possible. You can imagine it as assigning some tags to the input. This also means that the labels need to have the same shape as the logits, with a 0/1 value for each logit. Statistically speaking, for 592 output dimensions you predict 592 Bernoulli (= binary) distributions. The expected label shape is 32 × 56 × 592.
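
For illustration, a minimal sketch of this multi-label setup, assuming the labels were expanded to a 0/1 multi-hot tensor of the same shape as the logits (the shapes are taken from the question; the tensors here are random placeholders):

import torch

batch_size, seq_len, vocab_size = 32, 56, 592
logits = torch.randn(batch_size, seq_len, vocab_size)                               # raw model outputs
multi_hot_labels = torch.randint(0, 2, (batch_size, seq_len, vocab_size)).float()   # hypothetical 0/1 target per vocabulary entry

probs = torch.sigmoid(logits)                       # sigmoid, not softmax, for the multi-label case
loss = torch.nn.BCELoss()(probs, multi_hot_labels)  # works: both tensors are (32, 56, 592)

In practice `torch.nn.BCEWithLogitsLoss` on the raw logits is preferable, since it fuses the sigmoid into the loss for numerical stability.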

When using a softmax layer, you assume that exactly one target class is correct per position; you predict a single categorical distribution over the 592 possible output classes. In this case, however, the correct loss function is not binary cross-entropy but categorical cross-entropy, implemented by the CrossEntropyLoss class in PyTorch. Note that it takes the raw logits before softmax normalization and does the normalization internally. The expected label shape is 32 × 56, as in the code snippet.
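
If the labels are the usual integer token ids, a minimal sketch of this single-label case could look as follows (again with the shapes from the question). `CrossEntropyLoss` expects the class dimension right after the batch dimension, so the simplest route is to flatten the batch and sequence dimensions:

import torch

batch_size, seq_len, vocab_size = 32, 56, 592
logits = torch.randn(batch_size, seq_len, vocab_size)         # raw logits, no softmax
labels = torch.randint(0, vocab_size, (batch_size, seq_len))  # integer class ids in [0, 592)

loss = torch.nn.CrossEntropyLoss()(
    logits.view(-1, vocab_size),   # (32 * 56, 592)
    labels.view(-1),               # (32 * 56,)
)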

Jindřich
  • This makes sense. I one-hot encoded the `labels` onto the `vocab_size` so the dimensions of `labels` are the same as `logits`. Then calculated `CrossEntropyLoss` between them. Does that sound correct? – MNK Jan 27 '23 at 16:00
  • No, you either use one-hot labels with BCE or integer labels with cross-entropy. – Jindřich Jan 27 '23 at 21:27