In pointer networks, the output logits range over the length of the input. Working with batches therefore means padding the inputs to the maximum length in the batch. This is all fine until we have to compute the loss. Currently what I am doing is:
logits = stabilize(logits(inputs)) #[batch, max_length]. subtract max(logits) to stabilize
masks = masks(inputs) #[batch, max_length]. 1 for actual inputs, 0 for padded locations
exp_logits = exp(logits)
exp_logits_masked = exp_logits*masks
probs = exp_logits_masked/sum(exp_logits_masked, axis=1) # normalize each example over the length axis
Now I use these probabilities to compute the cross-entropy loss:
cross_entropy = -sum_over_batches(log(probs[correct_class]))
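For concreteness, here is a minimal runnable sketch of the above in TensorFlow; the function name masked_softmax_cross_entropy and the assumption that labels holds the index of the correct input position for each example are mine, not an established API:

import tensorflow as tf

def masked_softmax_cross_entropy(logits, masks, labels):
    # logits: [batch, max_length] float scores over input positions
    # masks:  [batch, max_length], 1.0 for real positions, 0.0 for padding
    # labels: [batch] integer index of the correct input position
    logits = logits - tf.reduce_max(logits, axis=1, keepdims=True)        # stabilize
    exp_logits = tf.exp(logits) * masks                                   # zero out padded positions
    probs = exp_logits / tf.reduce_sum(exp_logits, axis=1, keepdims=True) # renormalize per example
    batch_idx = tf.range(tf.shape(labels)[0])
    gather_idx = tf.stack([batch_idx, tf.cast(labels, tf.int32)], axis=1)
    correct_probs = tf.gather_nd(probs, gather_idx)                       # prob of the correct class
    return -tf.reduce_sum(tf.math.log(correct_probs + 1e-12))             # sum of negative log-likelihoods over the batch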
Can I do better than this? Any ideas on how this is generally done by people working with pointer networks?
If I didn't have variable-size inputs, all of this could be achieved with tf.nn.softmax_cross_entropy_with_logits
on the logits and labels (which is highly optimized), but with variable lengths that would produce erroneous results, since the softmax denominator grows by exp(0) = 1 for each zero-padded position in an input.
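To make that concrete, a small numeric sketch (hypothetical values: an input of true length 3, zero-padded to max_length 5) comparing an unmasked softmax with the masked one:

import numpy as np

logits = np.array([2.0, 1.0, 0.5, 0.0, 0.0])   # last two positions are padding
mask   = np.array([1.0, 1.0, 1.0, 0.0, 0.0])

plain  = np.exp(logits) / np.exp(logits).sum() # what an unmasked softmax computes
exp_l  = np.exp(logits - logits.max()) * mask
masked = exp_l / exp_l.sum()                   # renormalized over the real positions only

print(plain)   # denominator inflated by exp(0) = 1 per padded slot, diluting every probability
print(masked)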