I'm implementing a convolutional neural network in TensorFlow with Python. My situation is this: I have a tensor of batch labels y like this:
y = [[0,1,0],
     [0,0,1],
     [1,0,0]]
where each row is a one-hot vector representing the label of the corresponding example. During training I want to stop the loss gradient (i.e. set it to 0) for every example that has this label (like the third one):
[1,0,0]
which represents the n/a label, while the loss of the other examples in the batch should be computed normally. To compute the loss I use something like this:
self.y_loss = kl_divergence(self.pred_y, self.y)
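For any per-example masking to work, the loss has to stay a vector with one value per example until after the mask is applied; a scalar from kl_divergence cannot be split by example anymore. The NumPy sketch below illustrates that idea only. per_example_kl and NA_INDEX are hypothetical names (not the kl_divergence used above, whose argument order and reduction are unknown), and it assumes the n/a class is column 0, as in [1,0,0].

```python
import numpy as np

NA_INDEX = 0  # assumption: the n/a class is the first column, as in [1,0,0]

def per_example_kl(y_true, y_pred, eps=1e-12):
    """Hypothetical helper: KL(y_true || y_pred) per row, shape [batch]."""
    y_pred = np.clip(y_pred, eps, 1.0)
    # Entries with y_true == 0 contribute 0; replace them by 1 so log() is finite.
    safe_true = np.where(y_true > 0, y_true, 1.0)
    return np.sum(y_true * (np.log(safe_true) - np.log(y_pred)), axis=1)

y = np.array([[0, 1, 0],
              [0, 0, 1],
              [1, 0, 0]], dtype=np.float64)
# Made-up predicted distributions, one row per example.
pred_y = np.array([[0.2, 0.7, 0.1],
                   [0.1, 0.1, 0.8],
                   [0.3, 0.4, 0.3]])

losses = per_example_kl(y, pred_y)        # one loss value per example
keep = np.argmax(y, axis=1) != NA_INDEX   # False for n/a examples
masked_mean = losses[keep].mean()         # average over non-n/a examples only
print(keep, masked_mean)
```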
I found this function that stops the gradient, but how can I apply it conditionally to individual elements of the batch?
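One possible approach, sketched here in TensorFlow 2 eager mode (the question's self.y_loss style suggests TF1 graph code, so treat this as an illustration rather than a drop-in fix): compute the loss per example, build a boolean mask from the labels, and use tf.where to route n/a rows through tf.stop_gradient. The logits variable, NA_INDEX and per_example_kl are hypothetical stand-ins for the network output and the original kl_divergence, and the n/a class is again assumed to be column 0.

```python
import tensorflow as tf

NA_INDEX = 0  # assumption: position of the n/a class in the one-hot labels

y = tf.constant([[0., 1., 0.],
                 [0., 0., 1.],
                 [1., 0., 0.]])
# Hypothetical logits standing in for the CNN output.
logits = tf.Variable([[0.2, 1.5, -0.3],
                      [0.1, 0.0, 2.0],
                      [1.0, 0.5, 0.5]])

def per_example_kl(labels, logits):
    """Hypothetical helper: KL(labels || softmax(logits)), shape [batch]."""
    log_p = tf.nn.log_softmax(logits, axis=1)
    # xlogy returns 0 where labels == 0, so 0 * log(0) never produces nan.
    return tf.reduce_sum(tf.math.xlogy(labels, labels) - labels * log_p, axis=1)

# True for examples whose label is not n/a.
keep = tf.not_equal(tf.argmax(y, axis=1), NA_INDEX)

with tf.GradientTape() as tape:
    loss_vec = per_example_kl(y, logits)
    # Normal rows keep their gradient; n/a rows pass through stop_gradient.
    masked = tf.where(keep, loss_vec, tf.stop_gradient(loss_vec))
    y_loss = tf.reduce_mean(masked)

grad = tape.gradient(y_loss, logits)
print(grad.numpy())  # the third (n/a) row of the gradient is all zeros
```

Note that with stop_gradient the n/a rows still contribute their value to y_loss; only their gradient is blocked. If the value should also be excluded, multiplying loss_vec by tf.cast(keep, loss_vec.dtype) (and dividing by the number of kept rows) zeroes both the value and the gradient without needing stop_gradient at all.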