
I'm trying to write some code like below:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.Tensor([[1.0, 2.0, 3.0]]))
y = Variable(torch.LongTensor([1]))
w = torch.Tensor([1.0, 1.0, 1.0])
F.cross_entropy(x, y, w)
w = torch.Tensor([1.0, 10.0, 1.0])
F.cross_entropy(x, y, w)

However, the output of the cross-entropy loss is always 1.4076, no matter what w is. What does the weight parameter of F.cross_entropy() do, and how do I use it correctly?
I'm using PyTorch 0.3.

konchy

1 Answer


The weight parameter is used to compute a weighted average of the per-input losses, based on each input's target class. If you have only one input, or all inputs share the same target class, the class weight cancels out in the normalization and won't impact the loss.
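That cancellation is exactly what you observed: with a single input of class 1, the weighted mean is w[1] * loss / w[1]. A minimal sketch demonstrating it (using the modern tensor API rather than the 0.3-era Variable wrapper):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0]])
y = torch.tensor([1])

# With one input, the weighted mean is w[1] * loss / w[1],
# so the class weight cancels and both calls print the same
# value (~1.4076):
for w in (torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 10.0, 1.0])):
    print(F.cross_entropy(x, y, w).item())
```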

See the difference however with 2 inputs of different target classes:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.Tensor([[1.0,2.0,3.0], [1.0,2.0,3.0]]))
y = Variable(torch.LongTensor([1, 2]))
w = torch.Tensor([1.0,1.0,1.0])
res = F.cross_entropy(x,y,w)
# 0.9076
w = torch.Tensor([1.0,10.0,1.0])
res = F.cross_entropy(x,y,w)
# 1.3167
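For reference, the weighted result can be reproduced by hand: with mean reduction, each per-sample loss is multiplied by its target's class weight, and the sum is divided by the sum of those weights. A sketch of that computation (again using the modern tensor API, not Variable):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
y = torch.tensor([1, 2])
w = torch.tensor([1.0, 10.0, 1.0])

# Per-sample negative log-likelihoods of the target classes.
per_sample = -F.log_softmax(x, dim=1).gather(1, y.unsqueeze(1)).squeeze(1)

# Weighted mean: weight each sample by its target's class weight,
# then normalize by the sum of the applied weights.
manual = (w[y] * per_sample).sum() / w[y].sum()
print(manual.item())  # ~1.3167, matching F.cross_entropy(x, y, w)
```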
benjaminplanche
  • I did not find an actual expression that shows how the weights are used, and the C++ is hard for me to decipher; any clues where I can find these details? – pixelou Jul 30 '19 at 14:55
  • @pixelou: You can find the loss equation with `weights` in the `torch.nn.CrossEntropyLoss` [doc](https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss). You have the Python implementation [here](https://github.com/pytorch/pytorch/blob/22c169fb9cb618c660dab22edbe370c16a12a220/torch/nn/functional.py) otherwise. – benjaminplanche Jul 30 '19 at 16:12
  • My bad, I read this question too fast and thought it was about the binary CE. You still answered my question though ;-) : the Python code you linked shows that the binary version uses point-wise weights instead of the class-wise ones here. – pixelou Jul 31 '19 at 07:44
  • @pixelou: Glad it helped nevertheless! :) – benjaminplanche Jul 31 '19 at 08:49