Maybe someone can help me here. I am trying to compute the cross-entropy loss for a given output of my network:

print output
Variable containing:
1.00000e-02 *
-2.2739  2.9964 -7.8353  7.4667  4.6921  0.1391  0.6118  5.2227  6.2540 -7.3584
[torch.FloatTensor of size 1x10]

and the desired label, which is of the form

print lab
Variable containing:
x
[torch.FloatTensor of size 1]

where x is an integer between 0 and 9. According to the PyTorch documentation (http://pytorch.org/docs/master/nn.html),

criterion = nn.CrossEntropyLoss()
loss = criterion(output, lab)

should work, but unfortunately I get a strange error:

TypeError: FloatClassNLLCriterion_updateOutput received an invalid combination of arguments - got (int, torch.FloatTensor, !torch.FloatTensor!, torch.FloatTensor, bool, NoneType, torch.FloatTensor, int), but expected (int state, torch.FloatTensor input, torch.LongTensor target, torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight, int ignore_index)

Can anyone help me? I am really confused and have tried almost everything I could imagine would help.

Best

Elias E.

1 Answer

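Please check this code. Note that the target is a LongTensor holding the class index, which is the type the error message asks for: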

import torch
import torch.nn as nn
from torch.autograd import Variable

output = Variable(torch.rand(1,10))
target = Variable(torch.LongTensor([1]))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)

This will print out the loss nicely:

Variable containing:
 2.4498
[torch.FloatTensor of size 1]
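
If the label already exists as a float tensor, as in the question, casting it should be enough. The following is a minimal sketch, assuming a recent PyTorch release in which plain tensors are used instead of `Variable`; the names `lab_float` and `lab` and the class index 3 are illustrative and not taken from the question.

```python
import torch
import torch.nn as nn

# Stand-in for the network output in the question: one sample, ten class scores.
output = torch.randn(1, 10)

# A label stored as a float tensor, as in the question.
# The class index 3 is an arbitrary choice for this sketch.
lab_float = torch.tensor([3.0])

# CrossEntropyLoss expects class indices as a LongTensor (int64), so cast first.
lab = lab_float.long()

criterion = nn.CrossEntropyLoss()
loss = criterion(output, lab)
print(loss.item())
```

The cast only changes the dtype; the label must still hold a valid class index in the range 0 to 9 for a ten-class output.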
Sung Kim
  • Thanks, yes the problem was that the target variable has to be a Long tensor and was of type float in my code. Thank you! – Elias E. Nov 09 '17 at 21:07
  • Hi, what if I want to implement binary cross-entropy loss? How can I do that? Thanks – Rishabh Agrahari Mar 17 '18 at 10:12
  • The shape of `output` is `[1,10]` and the shape of `target` is `[1]`. Cross-entropy loss is `label * log (predicted)` for each class. So, during loss computation, does PyTorch use the same target label (1 here) for each value in `output`? – Rishabh Agrahari Aug 08 '18 at 14:37
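
On the last two comments: as far as I understand it, the target is not compared against every value in `output`. It is an index that selects one entry of the log-softmax of `output`, which is the same as multiplying by a one-hot label and summing over the classes. The sketch below checks this numerically and also shows one way to get a binary cross-entropy loss. It assumes a recent PyTorch release; the variable names and the example values 0.3 and 1.0 are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
output = torch.rand(1, 10)   # one sample, ten class scores (logits)
target = torch.tensor([1])   # class index for that sample

# Part 1: the class index picks a single log-probability.
log_probs = F.log_softmax(output, dim=1)
manual = -log_probs[0, target.item()]

# Equivalent one-hot form: -sum(label * log(predicted)) over the classes.
one_hot = F.one_hot(target, num_classes=10).float()
via_one_hot = -(one_hot * log_probs).sum(dim=1).mean()

builtin = F.cross_entropy(output, target)
print(manual.item(), via_one_hot.item(), builtin.item())

# Part 2: binary cross-entropy on a single logit.
# Here the target is a float tensor with the same shape as the logit.
logit = torch.tensor([[0.3]])
label = torch.tensor([[1.0]])
bce = F.binary_cross_entropy_with_logits(logit, label)

# The same quantity written out with an explicit sigmoid.
p = torch.sigmoid(logit)
manual_bce = -(label * torch.log(p) + (1 - label) * torch.log(1 - p)).mean()
print(bce.item(), manual_bce.item())
```

Unlike the multi-class case, the binary loss takes a float target, so no cast to a LongTensor is needed there.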