I am trying to figure out what the input to torch.nn.functional.gumbel_softmax (or Gumbel-Softmax in general) should be. From the original paper, it seems the authors use the normalized categorical log probabilities:
The Gumbel-Max trick (Gumbel, 1954; Maddison et al., 2014) provides a simple and efficient way to draw samples z from a categorical distribution with class probabilities π: z = one_hot(argmax_i[g_i + log π_i]), where g_1, ..., g_k are i.i.d. samples drawn from Gumbel(0, 1).
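To make sure I am reading the paper correctly, here is the Gumbel-Max trick as I understand it (a minimal sketch; the probability vector pi is a toy value of my own):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Normalized class probabilities pi (sum to 1), as in the paper's formula.
pi = torch.tensor([0.1, 0.6, 0.3])

# i.i.d. Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(U)), U ~ Uniform(0, 1).
g = -torch.log(-torch.log(torch.rand_like(pi)))

# z = one_hot(argmax_i [g_i + log pi_i]): a one-hot sample from Categorical(pi).
z = F.one_hot(torch.argmax(g + torch.log(pi)), num_classes=pi.numel())
print(z)  # one-hot indicator of the sampled class
```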
However, some other posts (2, 3), as well as the torch.nn.functional.gumbel_softmax docs, say the input should be the unnormalized log probabilities (logits).
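Concretely, I am calling it like this (a sketch; the logits values are arbitrary placeholders of mine):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Unconstrained scores, e.g. raw outputs of a linear layer.
logits = torch.tensor([[1.2, -0.4, 0.7]])

soft = F.gumbel_softmax(logits, tau=1.0)             # soft, differentiable sample
hard = F.gumbel_softmax(logits, tau=1.0, hard=True)  # one-hot, straight-through gradient
print(soft, hard)
```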
Intuitively, using unnormalized logits makes more sense to me, because adding Gumbel noise to unconstrained numbers seems more reasonable than adding it to normalized probabilities, but I would like a more concrete answer.
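One quick check I tried (a sketch; my reasoning is that log_softmax only shifts each row of the logits by a per-row constant, which I believe the softmax inside gumbel_softmax should cancel out, but I may be missing something):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 5)
log_probs = F.log_softmax(logits, dim=-1)  # = logits - logsumexp(logits), a per-row constant shift

torch.manual_seed(42)  # same Gumbel noise for both calls
a = F.gumbel_softmax(logits, tau=1.0)
torch.manual_seed(42)
b = F.gumbel_softmax(log_probs, tau=1.0)

print(torch.allclose(a, b))  # expected True: the constant shift cancels in the softmax
```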
I am aware there is a similar question, but it doesn't fully answer mine.