I'm trying to define the loss function for a two-class classification problem. However, the target label is not a hard label (0 or 1) but a float between 0 and 1.
torch.nn.CrossEntropyLoss in PyTorch does not support soft labels, so I'm trying to write a cross-entropy function myself.
My function looks like this:

def cross_entropy(self, pred, target):
    # intended soft-label cross entropy: -sum(target * log(pred))
    loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
    return loss
def step(self, batch: Any):
    x, y = batch
    logits = self.forward(x)
    loss = self.criterion(logits, y)
    preds = logits
    # torch.argmax(logits, dim=1)
    return loss, preds, y
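
One thing I noticed while debugging: self.forward returns raw logits, which can be negative, and torch.log of a negative number is nan. A minimal repro with made-up tensor values:

import torch

pred = torch.tensor([-0.5, 1.2, 0.3])   # raw logits, can be negative
target = torch.tensor([0.2, 0.9, 0.5])  # soft labels in [0, 1]

# same expression as in my cross_entropy above
loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
print(loss)  # tensor(nan), because log(-0.5) is nan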
However, it does not work at all.
Can anyone tell me whether there is a mistake in my loss function?
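
In case it clarifies what I'm after, here is the direction I was considering instead: torch.nn.functional.binary_cross_entropy_with_logits accepts float targets, and it applies the sigmoid internally for numerical stability. A sketch (soft_bce is just my own helper name, and the example values are made up):

import torch
import torch.nn.functional as F

def soft_bce(logits, target):
    # binary cross entropy with soft (float) targets in [0, 1];
    # the sigmoid is fused in, so raw logits are fine here
    return F.binary_cross_entropy_with_logits(logits.flatten(), target.flatten().float())

# made-up example values
logits = torch.tensor([-0.5, 1.2, 0.3])
target = torch.tensor([0.2, 0.9, 0.5])
print(soft_bce(logits, target))  # a finite scalar, no nan

Is something like this the right approach, or can my original cross_entropy be fixed?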