
I don't understand what the dim parameter applies to in torch.nn.Softmax. There is a warning telling me to specify it, and I set it to 1, but I don't understand what I am setting. Where is it used in the formula:

Softmax(x_i) = exp(x_i) / Σ_j exp(x_j)

There is no dim in this formula, so what does it apply to?

gruszczy

1 Answer


The PyTorch documentation on torch.nn.Softmax states: "dim (int) – A dimension along which Softmax will be computed (so every slice along dim will sum to 1)."

For example, if you have a matrix with two dimensions, you can choose whether you want to apply the softmax to the rows or the columns:

import torch

softmax0 = torch.nn.Softmax(dim=0)  # applies along columns
softmax1 = torch.nn.Softmax(dim=1)  # applies along rows

v = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])

softmax0(v)
# Returns
# [[0.0474, 0.0474, 0.0474],
#  [0.9526, 0.9526, 0.9526]]

softmax1(v)
# Returns
# [[0.0900, 0.2447, 0.6652],
#  [0.0900, 0.2447, 0.6652]]

Note how for softmax0 the columns add to 1, and for softmax1 the rows add to 1.
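A quick way to confirm which slices sum to 1 is to sum the output along the same dim. A minimal sketch using the functional torch.softmax (equivalent to the module form above), also showing that the same rule extends to higher-rank tensors, where dim picks one axis out of several:

```python
import torch

v = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])

# Every slice along `dim` sums to 1: check by summing along that axis
print(torch.allclose(torch.softmax(v, dim=0).sum(dim=0), torch.ones(3)))  # True
print(torch.allclose(torch.softmax(v, dim=1).sum(dim=1), torch.ones(2)))  # True

# For a rank-3 tensor, dim=2 normalizes each length-4 slice along
# the last axis, so summing over dim=2 gives a (2, 3) tensor of ones
t = torch.randn(2, 3, 4)
print(torch.allclose(torch.softmax(t, dim=2).sum(dim=2), torch.ones(2, 3)))  # True
```

This is also why a warning appears when dim is omitted: for anything beyond a 1-D input, PyTorch cannot guess which axis you meant to normalize.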

Max Crous