
I'm trying to define the loss function for a two-class classification problem. However, the target labels are not hard 0/1 labels but floating-point numbers between 0 and 1.

As far as I know, torch.nn.CrossEntropyLoss in PyTorch does not support soft labels, so I'm trying to write a cross-entropy function myself.
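For reference, here is the quantity such a hand-written function has to compute. It assumes the model outputs one raw logit z per sample and that p = σ(z) is the predicted probability of the positive class; t ∈ [0, 1] is the soft target. With a hard label one of the two terms vanishes, but with a soft target both contribute, and the batch loss is the mean over the N samples.

$$
\ell(z, t) = -\Big[\, t \log \sigma(z) + (1 - t) \log\big(1 - \sigma(z)\big) \Big],
\qquad \sigma(z) = \frac{1}{1 + e^{-z}},
\qquad L = \frac{1}{N} \sum_{i=1}^{N} \ell(z_i, t_i)
$$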

My function looks like this:

def cross_entropy(self, pred, target):
    loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
    return loss

def step(self, batch: Any):
    x, y = batch
    logits = self.forward(x)
    loss = self.criterion(logits, y)
    preds = logits
    # torch.argmax(logits, dim=1)
    return loss, preds, y

However, it does not work at all.

Can anyone tell me whether there is a mistake in my loss function?

Dishin H Goyani

1 Answer


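It seems that BCELoss and its more numerically robust counterpart BCEWithLogitsLoss (which takes raw logits and applies the sigmoid internally) work with fuzzy targets "out of the box". They do not expect the target to be binary: any number between zero and one is fine.
Please read the documentation for nn.BCELoss and nn.BCEWithLogitsLoss.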
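A minimal sketch of what this looks like in practice, assuming the model emits a single raw logit per sample (shape [N]) rather than two class scores. The tensors below are made-up illustrative values, and soft_bce_with_logits is a hypothetical helper written only to show the equivalent manual computation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative batch: 4 samples, one raw (un-squashed) logit each.
logits = torch.randn(4)
# Soft targets: probability that each sample belongs to the positive class.
soft_targets = torch.tensor([0.0, 0.3, 0.75, 1.0])

# BCEWithLogitsLoss applies the sigmoid internally, so it takes raw logits.
loss = nn.BCEWithLogitsLoss()(logits, soft_targets)

# BCELoss expects probabilities, so the sigmoid has to be applied first.
loss_probs = nn.BCELoss()(torch.sigmoid(logits), soft_targets)


def soft_bce_with_logits(z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical manual soft-label binary cross-entropy on logits.

    Uses logsigmoid for numerical stability:
    log(sigmoid(z)) == logsigmoid(z) and log(1 - sigmoid(z)) == logsigmoid(-z).
    """
    return -(t * F.logsigmoid(z) + (1 - t) * F.logsigmoid(-z)).mean()


manual = soft_bce_with_logits(logits, soft_targets)
print(loss.item(), loss_probs.item(), manual.item())
```

All three values should agree up to floating-point error. Compared with the function in the question, the manual version keeps the (1 - t) term for the negative class, works on logits through logsigmoid instead of taking the log of raw model outputs (which produces NaN for negative logits, assuming pred is the raw output as step suggests), and averages per-sample losses instead of wrapping an already-summed scalar in torch.mean.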

Shai