The main logic inside pyg.utils.softmax() is as follows:
N = maybe_num_nodes(index, num_nodes)
src_max = scatter(src, index, dim, dim_size=N, reduce='max')
src_max = src_max.index_select(dim, index)
out = (src - src_max).exp()
out_sum = scatter(out, index, dim, dim_size=N, reduce='sum')
out_sum = out_sum.index_select(dim, index)
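To make sure I am reading these steps correctly, here is a plain-Python sketch of what I believe they compute for a 1-D input. The helper name `grouped_softmax` and the list-based inputs are my own stand-ins, not PyG code, and the final division is my assumption, since the excerpt above stops at `out_sum`.

```python
import math

def grouped_softmax(src, index, num_groups):
    """Hypothetical 1-D stand-in for the quoted steps (not the PyG implementation)."""
    # Like scatter(..., reduce='max'): the maximum of each group.
    group_max = [-math.inf] * num_groups
    for value, g in zip(src, index):
        group_max[g] = max(group_max[g], value)

    # Like index_select followed by (src - src_max).exp().
    out = [math.exp(value - group_max[g]) for value, g in zip(src, index)]

    # Like scatter(..., reduce='sum'): the sum of exponentials in each group.
    group_sum = [0.0] * num_groups
    for value, g in zip(out, index):
        group_sum[g] += value

    # Assumed final step: normalise each entry by its group's sum.
    return [value / group_sum[g] for value, g in zip(out, index)]

# Two groups: entries 0-1 belong to group 0, entries 2-3 to group 1.
probs = grouped_softmax([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1], num_groups=2)
```

In this sketch each group's outputs sum to 1, so the result looks like an ordinary softmax applied per group.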
But it confuses me why there is a scatter(reduce='max') (or scatter_max()) operation at the beginning. As we all know, softmax is defined as $\mathrm{softmax}(e_i) = \frac{\exp(e_i)}{\sum_{0 \le j \le N} \exp(e_j)}$.
I don't see any max operation in that definition, so I can't understand why it is there.
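One possible reading, which I have not confirmed against the PyG source, is that the max is subtracted only for numerical stability. Multiplying numerator and denominator by $\exp(-m)$ leaves the ratio unchanged, so $\exp(e_i - m) / \sum_j \exp(e_j - m)$ equals the usual softmax for any constant $m$. The plain-Python check below uses two hypothetical helpers, `naive_softmax` and `shifted_softmax`, to compare the two forms.

```python
import math

def naive_softmax(xs):
    # Textbook definition: exp(x_i) / sum_j exp(x_j).
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def shifted_softmax(xs):
    # Same formula after subtracting the maximum from every input.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

small = [1.0, 2.0, 3.0]
naive_small = naive_softmax(small)
shifted_small = shifted_softmax(small)

# With large inputs, math.exp overflows in the naive version.
large = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

# The shifted version only ever exponentiates values <= 0.
shifted_large = shifted_softmax(large)
```

If this reading is right, the max step changes nothing mathematically and only keeps `exp()` from overflowing on large inputs.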