
According to the PyTorch documentation of nn.BCEWithLogitsLoss, pos_weight is an optional argument that takes the weight of positive examples. I don't fully understand the statement "pos_weight > 1 increases recall and pos_weight < 1 increases precision" on that page. How do you understand this statement?


1 Answer


The binary cross-entropy with logits loss (nn.BCEWithLogitsLoss, equivalent to F.binary_cross_entropy_with_logits) is a sigmoid layer (nn.Sigmoid) followed by a binary cross-entropy loss (nn.BCELoss). The general case is a multi-label classification task, i.e. a single input can be labeled with multiple classes. A common sub-case is the binary classification task, with a single class. Define q as your tensor of predicted logits and p as the ground-truth tensor, with values in [0, 1] corresponding to the true probabilities for each class.
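As a minimal sketch of that equivalence (the tensors q and p below are made-up placeholders, not from the original question), you can check that the one-step and two-step versions give the same loss:

```python
import torch
import torch.nn as nn

q = torch.randn(4, 3)                     # logits for a batch of 4 inputs, 3 classes
p = torch.randint(0, 2, (4, 3)).float()   # ground-truth probabilities in {0, 1}

loss_one_step = nn.BCEWithLogitsLoss()(q, p)            # sigmoid + BCE fused
loss_two_steps = nn.BCELoss()(nn.Sigmoid()(q), p)       # explicit sigmoid, then BCE

print(torch.allclose(loss_one_step, loss_two_steps))    # True (up to numerical precision)
```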

The explicit formulation for the binary cross-entropy would be:

z = torch.sigmoid(q)  # predicted probabilities
loss = -(w_p*p*torch.log(z) + (1-p)*torch.log(1-z))  # element-wise BCE, positive term scaled by w_p

where w_p is the weight associated with the positive label for each class. Read this post for more details on the weighting scheme used by BCELoss.
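A quick sanity check (with made-up tensors; the mean reduction matches the default of the built-in loss) that this explicit formula agrees with the pos_weight argument:

```python
import torch
import torch.nn.functional as F

q = torch.randn(4, 3)                     # logits
p = torch.randint(0, 2, (4, 3)).float()   # targets
w_p = torch.tensor([1.0, 2.0, 0.5])       # one positive weight per class

# explicit formulation, averaged over all elements
z = torch.sigmoid(q)
manual = -(w_p * p * torch.log(z) + (1 - p) * torch.log(1 - z)).mean()

# built-in version with pos_weight
builtin = F.binary_cross_entropy_with_logits(q, p, pos_weight=w_p)

print(torch.allclose(manual, builtin))    # True
```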

For a given class:

precision =  TP / (TP + FP)
recall = TP / (TP + FN)

Then if w_p > 1, the loss puts more weight on the positive term, pushing the model towards classifying as positive. This tends to reduce false negatives (FN) and increase false positives (FP), so recall goes up and precision goes down. Similarly, if w_p < 1, we are decreasing the weight on the positive term, so the model classifies as positive less often; this tends to increase false negatives (FN) and decrease false positives (FP), so precision goes up and recall goes down. A small numerical illustration is sketched below.
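The illustration below uses made-up logits: one positive example scored as negative (the false-negative direction) and one negative example scored as positive (the false-positive direction). Only the penalty on the missed positive scales with pos_weight, which is what tilts the trade-off between recall and precision:

```python
import torch
import torch.nn.functional as F

q_fn, p_pos = torch.tensor([-2.0]), torch.tensor([1.0])  # positive example scored negative
q_fp, p_neg = torch.tensor([2.0]), torch.tensor([0.0])   # negative example scored positive

for w in (0.5, 1.0, 3.0):
    w_p = torch.tensor([w])
    fn_loss = F.binary_cross_entropy_with_logits(q_fn, p_pos, pos_weight=w_p)
    fp_loss = F.binary_cross_entropy_with_logits(q_fp, p_neg, pos_weight=w_p)
    print(f"w_p={w}: missed positive loss={fn_loss.item():.3f}, "
          f"false alarm loss={fp_loss.item():.3f}")

# As w_p grows, the missed-positive loss grows while the false-alarm loss stays
# the same, so the optimizer is pushed to predict positive more often
# (higher recall, lower precision); w_p < 1 does the opposite.
```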
