
I have an output tensor (both target and predicted) of dimension (32 x 8 x 5000). Here, the batch size is 32, the number of classes is 5000 and the number of points per batch is 8. I want to calculate CELoss on this in such a way that, the loss is computed for every point (across 5000 classes) and then averaged across the 8 points. How can I do this?

For clarity, there are 32 batch points in a batch (for bs=32). Each batch point has 8 vector points, and each vector point has 5000 classes. For a given batch, I wish to compute CELoss across all (8) vector points, compute their average and do so for all the batch points (32).

Let me know if my question is unclear or ambiguous.

For example:

import torch

op = torch.rand((4,3,5))

gt = torch.tensor([
    [[0,1,1,0,0],[0,0,1,0,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[0,0,0,1,0],[0,0,1,0,0]],
    [[0,0,1,0,0],[1,1,1,1,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[1,1,0,0,1],[1,0,0,0,0]]
])
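For the single-label reading of the question (exactly one true class per point), the per-point cross-entropy averaged over points could be sketched as follows. Note this uses a hypothetical index tensor `labels` rather than the multi-hot `gt` above, since `nn.CrossEntropyLoss` expects class indices (or a probability distribution) and the class dimension second:

```python
import torch

# Hypothetical single-label setup: logits of shape (batch, points, classes)
logits = torch.randn(32, 8, 5000)
labels = torch.randint(0, 5000, (32, 8))  # one true class index per point

# nn.CrossEntropyLoss expects (N, C, d), so move classes to dim 1:
# (32, 8, 5000) -> (32, 5000, 8); reduction="mean" then averages over
# all 32 * 8 points, which includes the per-batch-point average over 8.
loss = torch.nn.CrossEntropyLoss()(logits.permute(0, 2, 1), labels)
```

This does not handle the multi-hot targets shown above; that case is what the comments below discuss.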
  • can you provide a reproducible example with a reduced-size batch? – Salvatore Daniele Bianco Oct 10 '22 at 12:43
  • [Salvatore](https://stackoverflow.com/users/11728488/salvatore-daniele-bianco) I've updated the question. Let me know if something else would be needed – helloworld Oct 10 '22 at 14:14
  • sorry, but I expect to have a single true class in `gt` for each point of each element instead of multiple class in the vectors. Am I wrong? I mean that `gt.sum(dim=-1)` has to return a tensor of ones with shape `batch_size x points` – Salvatore Daniele Bianco Oct 10 '22 at 14:41
  • In this case, I'm expecting multiple true values for every point – helloworld Oct 10 '22 at 14:50
  • ok. In this case your problem is not to classify the points in 5000 classes. Your problem is to perform 5000 binary classifications. So you have to use binary cross entropy rather than multi-class cross entropy. Am I right? – Salvatore Daniele Bianco Oct 10 '22 at 14:57
  • Sorry for the confusion. In my case, i don't want to predict multiple classes at the same time. Rather, I'm looking for the model to predict any one/many of the ground truths – helloworld Oct 11 '22 at 05:33
  • Conceptually, a given sample in exactly one class (say, class 3), but for training purposes, predicting class 2 or 5 is still okay so the model isn't penalized too heavily. This explains the multiple 1s for a given example's ground truth. – helloworld Oct 11 '22 at 05:37
  • 1
    OK, but in this case simply you can't apply the cross-entropy loss. you should define a custom loss function based on this specific task. – Salvatore Daniele Bianco Oct 11 '22 at 08:22
  • Yeah. I'm stuck designing this custom loss in itself. There aren't a lot of resources for tasks similar to this – helloworld Oct 11 '22 at 08:50
  • I think the binary cross-entropy approach might work. – Salvatore Daniele Bianco Oct 12 '22 at 09:36
  • Won't BCE force the model to predict all classes? – helloworld Oct 12 '22 at 14:22
  • 1
    of course not. Conceptually it is like training 5000 separate binary classifiers optimized to predict the probability of each class independently from other classes. Try to train and over-fit a model in this way to see how it works. – Salvatore Daniele Bianco Oct 12 '22 at 15:08
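One possible sketch of the "any acceptable class" idea discussed in the comments (`any_acceptable_loss` is a hypothetical name and my own construction, not from the thread): penalize only the total softmax probability the model assigns to the acceptable classes, so predicting any one of them is rewarded equally.

```python
import torch
import torch.nn.functional as F

def any_acceptable_loss(logits, multi_hot):
    # -log P(any acceptable class). Each point must have at least one 1
    # in multi_hot, otherwise the masked logsumexp below is -inf.
    log_p = F.log_softmax(logits, dim=-1)
    # logsumexp over the acceptable classes = log of their summed probability
    masked = log_p.masked_fill(multi_hot == 0, float("-inf"))
    return -torch.logsumexp(masked, dim=-1).mean()

logits = torch.randn(4, 3, 5)
gt = torch.tensor([
    [[0,1,1,0,0],[0,0,1,0,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[0,0,0,1,0],[0,0,1,0,0]],
    [[0,0,1,0,0],[1,1,1,1,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[1,1,0,0,1],[1,0,0,0,0]]
])
loss = any_acceptable_loss(logits, gt)
```

This is only a sketch of one option; the accepted answer below takes the binary cross-entropy route instead.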

1 Answer


DATA

import torch

op = torch.rand((4,3,5))
gt = torch.tensor([
    [[0,1,1,0,0],[0,0,1,0,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[0,0,0,1,0],[0,0,1,0,0]],
    [[0,0,1,0,0],[1,1,1,1,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[1,1,0,0,1],[1,0,0,0,0]]
], dtype=torch.float)

Now, if your output is in [0,1] (if it is not, add a Sigmoid activation at the end of your model), you can compute the binary cross-entropy losses (N_class values for each point of each element) like this:

torch.nn.BCELoss(reduction="none")(op, gt)

Finally, you can compute the average loss for each element of the batch as:

torch.nn.BCELoss(reduction="none")(op, gt).mean(dim=[-1,-2])
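To illustrate the shapes involved (using random multi-hot targets rather than the `gt` above, just for brevity):

```python
import torch

torch.manual_seed(0)
op = torch.rand((4, 3, 5))                   # predictions already in [0, 1]
gt = torch.randint(0, 2, (4, 3, 5)).float()  # multi-hot targets

per_class = torch.nn.BCELoss(reduction="none")(op, gt)
print(per_class.shape)    # one loss per class per point: (4, 3, 5)

per_element = per_class.mean(dim=[-1, -2])
print(per_element.shape)  # one averaged loss per batch element: (4,)
```

With your real sizes, `per_element` would have shape `(32,)`: the BCE over the 5000 classes averaged over the 8 points of each batch element.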

If this is not the solution you are looking for, or if something is unclear, let me know.
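One more practical point: if your model outputs raw logits instead of probabilities, `torch.nn.BCEWithLogitsLoss` applies the sigmoid internally in a numerically more stable way, so you can drop the final Sigmoid layer:

```python
import torch

logits = torch.randn(4, 3, 5)                # raw outputs, no Sigmoid layer
gt = torch.randint(0, 2, (4, 3, 5)).float()

# Same per-class losses as Sigmoid + BCELoss, but fused and more stable
per_element = torch.nn.BCEWithLogitsLoss(reduction="none")(logits, gt).mean(dim=[-1, -2])
```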