
I want to create a custom loss function in Torch which is a modification of ClassNLLCriterion. Concretely, ClassNLLCriterion loss is:

loss(x, class) = -x[class]

I want to modify this to be:

loss(x, class) = -x[class]*K

where K is a function of the network input, NOT the network weights or network output. Thus K can be treated as a constant with respect to the criterion's input.

What is the easiest way of implementing this custom criterion? The updateOutput() function seems straightforward, but how do I modify the updateGradInput() function?

braindead

1 Answer


Basically your loss function L is a function of the input and the target. So you have

loss(input, target) = ClassNLLCriterion(input, target) * K

if I understand your new loss correctly. You then want to implement updateGradInput, which returns the derivative of your loss function with respect to the input, namely

updateGradInput[ClassNLLCriterion](input, target) * K + ClassNLLCriterion(input, target) * dK/dinput

Therefore you only have to compute the derivative of K with respect to the input of the loss function (you did not give us the formula for K) and plug it into the previous line. Since your new loss function relies on ClassNLLCriterion, you can reuse the updateOutput and updateGradInput of that criterion to compute yours.
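To make this concrete, here is a minimal sketch of such a wrapper criterion in Lua/Torch. It assumes K (and, if needed, dK/dinput) is computed elsewhere from the network input and assigned before each forward/backward call; the names `WeightedNLLCriterion`, `K` and `dKdInput` are made up for the example, not part of the nn library.

```lua
require 'nn'

local WeightedNLLCriterion, parent = torch.class('nn.WeightedNLLCriterion', 'nn.Criterion')

function WeightedNLLCriterion:__init()
   parent.__init(self)
   self.nll = nn.ClassNLLCriterion()
   self.K = 1           -- set this before each forward/backward call
   self.dKdInput = nil  -- optional: gradient of K w.r.t. the criterion input
end

function WeightedNLLCriterion:updateOutput(input, target)
   -- loss = ClassNLLCriterion(input, target) * K
   self.nllOutput = self.nll:updateOutput(input, target)
   self.output = self.nllOutput * self.K
   return self.output
end

function WeightedNLLCriterion:updateGradInput(input, target)
   -- product rule: dL/dinput = K * dNLL/dinput + NLL * dK/dinput
   self.gradInput = self.nll:updateGradInput(input, target):clone():mul(self.K)
   if self.dKdInput then
      self.gradInput:add(self.nllOutput, self.dKdInput)  -- gradInput += NLL * dK/dinput
   end
   return self.gradInput
end
```

If K really does not depend on the criterion's input (as stated in the question), you can leave `dKdInput` nil and the second term simply vanishes.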

fonfonx
  • So essentially I don't have to write a custom criterion. In my training code, I can simply do: `loss = ClassNLLCriterion:forward()*K` and then `grad = ClassNLLCriterion:backward()*K+loss*(dK/dinput)` Is this correct? – braindead Jun 02 '17 at 19:44
  • Yeah this is also possible – fonfonx Jun 02 '17 at 19:48
  • Awesome. Thanks! One other question, if K is simply a constant (not dependent on network parameters or input or output), how would your answer change in that case? – braindead Jun 02 '17 at 19:49
  • if K is simply a constant I don't see the point of using it... It would just multiply all loss values by the same constant. So you could simply use the `ClassNLLCriterion` – fonfonx Jun 02 '17 at 19:53
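For the record, the inline approach suggested in the comments could look roughly like this. It is a sketch, not a tested implementation; `model`, `input`, `target` and `computeK` are placeholders for your own network, data, and K function.

```lua
require 'nn'

local criterion = nn.ClassNLLCriterion()

-- inside the training loop
local K = computeK(input)                          -- K depends only on the network input
local output = model:forward(input)
local loss = criterion:forward(output, target) * K
-- K is constant w.r.t. the network output, so the dK/dinput term vanishes;
-- if K also depended on the criterion input, you would add loss * dK/dinput here
local gradOutput = criterion:backward(output, target):clone():mul(K)
model:backward(input, gradOutput)
```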