
I'm trying to solve a binary classification problem (target=0 and target=1) with one exception: some of my labels are deliberately set to target=0.5, and for those I want zero loss whether the model classifies them as 0 or as 1 (i.e. both classes are "correct").

I tried to implement a custom loss from scratch, based on PyTorch's BCEWithLogitsLoss:

import torch
import torch.nn.functional as F


class myLoss(torch.nn.Module):

    def __init__(self, pos_weight=1):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, input, target):
        epsilon = 10 ** -44
        my_bce_loss = -1 * (self.pos_weight * target * F.logsigmoid(input + epsilon)
                            + (1 - target) * torch.log(1 - torch.sigmoid(input) + epsilon))
        add_loss = (target - 0.5) ** 2 * 4
        mean_loss = (my_bce_loss * add_loss).mean()
        return mean_loss
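
Ignoring epsilon, the forward pass above computes, for a logit $x$, a target $t$ and pos_weight $p$:

$$
\ell(x, t) = -\,w(t)\left[\, p\,t \log \sigma(x) + (1 - t)\log\bigl(1 - \sigma(x)\bigr) \right],
\qquad w(t) = 4\left(t - \tfrac{1}{2}\right)^2
$$

so $w(0) = w(1) = 1$ gives the usual weighted BCE, while $w(\tfrac{1}{2}) = 0$ makes the loss zero for targets of 0.5, whatever the prediction.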

epsilon was chosen so that the log is bounded below at about -100 (log(1e-44) ≈ -101.3), mirroring BCELoss, which clamps its log outputs to be greater than or equal to -100.
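
A minimal sketch of how that epsilon is actually represented, assuming the default float32 dtype (which the dtype=torch.float32 in the update below suggests). It checks three things: the stored value of 1e-44, the value of 1/epsilon (the derivative of log at epsilon), and whether sigmoid saturates to exactly 1.0 for a moderately large logit (30 is an arbitrary choice):

```python
import torch

eps = 10 ** -44  # the epsilon from the loss above

# 1e-44 is below the smallest normal float32 (~1.2e-38), so it is stored
# as a subnormal value close to 9.8e-45.
t = torch.tensor(eps)  # default dtype: float32
print(t.dtype, t.item())  # torch.float32 and a value close to 9.8e-45

# The derivative of log at that point, 1/eps, exceeds the float32
# maximum (~3.4e38) and overflows to inf.
print(1.0 / t)  # tensor(inf)

# sigmoid rounds to exactly 1.0 in float32 for a moderately large logit,
# so 1 - sigmoid(x) + eps collapses to that subnormal eps.
s = torch.sigmoid(torch.tensor(30.0))
print(s == 1.0)  # tensor(True)
```

If the sigmoid saturates like this, backpropagating through log(1 - sigmoid(input) + epsilon) divides by that subnormal epsilon, giving inf, and the sigmoid's own derivative at that point is 0; inf * 0 is NaN in floating point. This is one plausible source of the NaN, not a confirmed one.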

However, I still get NaN errors after several epochs:

Function 'LogBackward' returned nan values in its 0th output.

or

Function 'SigmoidBackward' returned nan values in its 0th output.

Any suggestions on how I can correct my loss function? Maybe by somehow inheriting from BCEWithLogitsLoss and modifying its forward function?


Update: this is how I call my custom loss function:

y = batch[:, -1, :].to(self.device, dtype=torch.float32)
y_pred_batch = self.model(x)

LossFun = myLoss(self.pos_weight)
batch_result.loss = LossFun(y_pred_batch, y)  # __call__ runs forward()

I use a Temporal Convolutional Network (TCN) model, whose forward pass is implemented as follows:

out = self.conv1(x)
out = self.chomp1(out)               
out = self.elu(out) 
out = self.dropout1(out)
res = x if self.downsample is None else self.downsample(x)
return self.tanh(out + res)
Shlomi Shmuel

1 Answer


Try it this way:

import torch


class myLoss(torch.nn.Module):

    def __init__(self, pos_weight=1):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, input, target):
        epsilon = 10 ** -44
        # Convert logits to probabilities, clamped to [epsilon, 1 - epsilon]
        input = input.sigmoid().clamp(epsilon, 1 - epsilon)

        my_bce_loss = -1 * (self.pos_weight * target * torch.log(input)
                            + (1 - target) * torch.log(1 - input))
        add_loss = (target - 0.5) ** 2 * 4
        mean_loss = (my_bce_loss * add_loss).mean()
        return mean_loss

To test it, I run 1000 backward passes:


target = torch.randint(high=2, size=(32,))
loss_fn = myLoss()
for i in range(1000):
    inp = torch.rand(1, 32, requires_grad=True)
    loss = loss_fn(inp, target)
    loss.backward()
    if torch.isnan(loss):
        print('Loss NaN')
    if torch.isnan(inp.grad).any():
        print('NaN')


Everything works fine: no NaN shows up in either the loss or the gradients.
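
As an alternative sketch (not part of the answer above; the class name MaskedBCEWithLogitsLoss is hypothetical), the same weighted loss can be delegated to F.binary_cross_entropy_with_logits. Its weight argument takes the per-element factor 4 * (target - 0.5) ** 2, and its pos_weight argument multiplies the positive term exactly as self.pos_weight does in the question's formula. Because it works on logits directly, PyTorch documents it as more numerically stable than a separate sigmoid followed by a log. The stress test uses logits with magnitudes up to a few hundred, which the torch.rand inputs in the test above never produce:

```python
import torch
import torch.nn.functional as F


class MaskedBCEWithLogitsLoss(torch.nn.Module):
    """Hypothetical alternative: the question's weighted BCE, computed by
    PyTorch's built-in binary_cross_entropy_with_logits."""

    def __init__(self, pos_weight=1.0):
        super().__init__()
        # binary_cross_entropy_with_logits expects pos_weight as a tensor
        self.register_buffer("pos_weight", torch.tensor(float(pos_weight)))

    def forward(self, input, target):
        # Same per-element weight as in the question: 1 for targets 0 and 1,
        # 0 for targets of 0.5, so those elements contribute no loss.
        weight = 4 * (target - 0.5) ** 2
        return F.binary_cross_entropy_with_logits(
            input, target, weight=weight, pos_weight=self.pos_weight
        )


# Stress test with saturating logits.
torch.manual_seed(0)
loss_fn = MaskedBCEWithLogitsLoss(pos_weight=2.0)
target = torch.tensor([0.0, 1.0, 0.5] * 4)
inp = (torch.randn(12) * 100).requires_grad_()
loss = loss_fn(inp, target)
loss.backward()
print(torch.isfinite(loss).item(), torch.isfinite(inp.grad).all().item())  # True True
```

Since weight is zero for targets of 0.5, those elements also receive zero gradient, so the model is not pushed toward either class on them.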

Guillem