
I am trying to figure out what the expected input to torch.gumbel_softmax (or Gumbel-Softmax in general) is. The original paper appears to use normalized categorical log-probabilities:

> The Gumbel-Max trick (Gumbel, 1954; Maddison et al., 2014) provides a simple and efficient way to draw samples z from a categorical distribution with class probabilities π: `z = one_hot(argmax_i [g_i + log π_i])`
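If I read that quote right, a direct implementation of the trick would look something like the sketch below (my own code, not from the paper; `gumbel_max_sample` and `probs` are names I made up, with `probs` holding the normalized class probabilities π):

```python
import torch
import torch.nn.functional as F

def gumbel_max_sample(probs: torch.Tensor) -> torch.Tensor:
    u = torch.rand_like(probs)         # U ~ Uniform(0, 1)
    g = -torch.log(-torch.log(u))      # g ~ Gumbel(0, 1) by inverse transform
    # z = one_hot(argmax_i [g_i + log pi_i])
    idx = torch.argmax(g + torch.log(probs), dim=-1)
    return F.one_hot(idx, num_classes=probs.shape[-1])
```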

However, some other posts (2, 3), as well as the torch.nn.functional.gumbel_softmax docs, say the input should be unnormalized log-probabilities (logits).
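For reference, the usage the docs describe looks like this, where `logits` are raw scores (e.g., a linear layer's output) with no softmax or log-softmax applied; the shapes here are just for illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)  # batch of 8 over 10 categories, unnormalized
y_soft = F.gumbel_softmax(logits, tau=1.0, hard=False)  # differentiable soft sample
y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)   # one-hot, straight-through gradient
```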

Intuitively, using unnormalized logits makes more sense to me, since it seems more reasonable to add Gumbel noise to unconstrained numbers than to normalized probabilities, but I would like a more concrete answer.
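One sanity check I can think of (my own snippet, not from the docs): log-softmax only shifts each row of the logits by a per-row constant, and softmax is shift-invariant, so with the same Gumbel noise the two inputs should produce identical samples:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)

torch.manual_seed(0)
a = F.gumbel_softmax(logits, tau=1.0)
torch.manual_seed(0)  # reset the RNG so both calls draw the same Gumbel noise
b = F.gumbel_softmax(F.log_softmax(logits, dim=-1), tau=1.0)
print(torch.allclose(a, b))  # expected True: the per-row shift cancels inside the softmax
```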

I am aware there is a similar question, but it doesn't fully answer mine.
