
Both MXNet and PyTorch provide a special implementation for computing log(softmax()), which is faster and numerically more stable. However, I cannot find the actual Python implementation for this function, log_softmax(), in either package.

Can anyone explain how this is implemented, or better, point me to the relevant source code?

desertnaut
herrlich10

2 Answers

  • The numerical error:
>>> x = np.array([1, -10, 1000])
>>> np.exp(x) / np.exp(x).sum()
RuntimeWarning: overflow encountered in exp
RuntimeWarning: invalid value encountered in true_divide
array([ 0.,  0., nan])

There are two methods to avoid numerical errors while computing the softmax:

  • Exp Normalization:

softmax(x)_i = exp(x_i - b) / sum_j exp(x_j - b)

def exp_normalize(x):
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()

>>> exp_normalize(x)
array([0., 0., 1.])
  • Log Sum Exp

log_softmax(x)_i = x_i - c - log( sum_j exp(x_j - c) )

def log_softmax(x):
    c = x.max()
    logsumexp = np.log(np.exp(x - c).sum())
    return x - c - logsumexp

Note that a reasonable choice for both b and c in the formulas above is max(x). With this choice, overflow due to exp is impossible: the largest argument exponentiated after shifting is 0.
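To check that the shift actually prevents the overflow, here is a quick sketch (assuming NumPy) that runs log_softmax on the problematic input from the question and verifies the result exponentiates back to a valid probability distribution:

```python
import numpy as np

def log_softmax(x):
    # shift by the maximum so the largest exponentiated argument is 0
    c = x.max()
    logsumexp = np.log(np.exp(x - c).sum())
    return x - c - logsumexp

x = np.array([1, -10, 1000])
out = log_softmax(x)          # all values finite, no overflow warning
np.exp(out).sum()             # sums to 1, a valid distribution
```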

herrlich10
Jonny Vu

You can find one of the CPU implementations here and a vectorized version here (this is the log version, called from vec_host_softmax_lastdim).

You can find a CUDA implementation here, which then calls softmax_warp_forward.

They are all similar; only the syntax differs. As you can see, there is usually a flag that defines whether or not softmax will be computed using the log, i.e., LogSoftMax instead of SoftMax.
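As a rough Python sketch of that flag pattern (this mirrors the structure of those kernels, not the actual C++/CUDA source; the function name is made up), a single routine can serve both modes:

```python
import numpy as np

def host_softmax(x, log_softmax=False):
    # Hypothetical sketch of a kernel with a LogSoftMax flag;
    # both paths share the max-shift for numerical stability.
    c = x.max()
    shifted = x - c
    sum_exp = np.exp(shifted).sum()
    if log_softmax:
        # log path: subtract log-sum-exp instead of exponentiating and dividing
        return shifted - np.log(sum_exp)
    return np.exp(shifted) / sum_exp

x = np.array([1.0, 2.0, 3.0])
host_softmax(x)                    # regular softmax
host_softmax(x, log_softmax=True)  # log-softmax
```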

Berriel
  • Thank you for pointing to the actual C++ implementation, which uses the same trick as described in Jonny Vu's answer. – herrlich10 May 03 '20 at 14:38
  • @herrlich10 yeah, as I said, they are all identical, syntax is the only thing that differs. I thought you knew the trick and was just looking for the source-code :) I would've explained the trick as well. – Berriel May 03 '20 at 17:54