Softmax or LogSoftmax
2 min read · Jan 20, 2021
As a machine learning engineer, you should be pretty familiar with the softmax function.
The softmax function is pretty nice as it turns any real-valued vector into a probability distribution: softmax(x)_i = e^(x_i) / Σ_j e^(x_j). However, the exponential can be evil: e^x gets huge even for modest x. For example, when x is 19 we already get e¹⁹ ≈ 178482301. Pretty large. Once the exponentials overflow, the division produces inf and then NaN.
>>> import math
>>> for x in range(20): print(x, math.e**x)
...
0 1.0
1 2.718281828459045
2 7.3890560989306495
3 20.085536923187664
4 54.59815003314423
5 148.41315910257657
6 403.428793492735
7 1096.6331584284583
8 2980.957987041727
9 8103.08392757538
10 22026.465794806703
11 59874.14171519778
12 162754.79141900383
13 442413.3920089202
14 1202604.2841647759
15 3269017.372472108
16 8886110.520507865
17 24154952.753575277
18 65659969.13733045
19 178482300.96318707
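In fact, with even moderately large logits the exponentials overflow float32 (which tops out around 3.4e38, roughly e^88), and a naively implemented softmax turns into inf and then NaN. A minimal sketch of the failure, using a made-up naive_softmax helper purely for illustration:

import torch

def naive_softmax(x, dim=-1):
    # Direct translation of the formula exp(x) / sum(exp(x)),
    # with no max-subtraction, so large inputs overflow to inf.
    e = torch.exp(x)
    return e / e.sum(dim=dim, keepdim=True)

x = torch.tensor([10.0, 100.0, 1000.0])
print(naive_softmax(x))         # tensor([0., nan, nan]) -- inf / inf gives NaN
print(torch.softmax(x, dim=0))  # tensor([0., 0., 1.])   -- the built-in is stabilized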
How do we solve this problem?
“You can use nn.LogSoftmax, it is numerically more stable and is less likely to NaN than using Softmax” — from here
Why is LogSoftmax more stable?
A nice discussion can be found here.
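The key idea (my summary of the usual argument, not a quote from that discussion): LogSoftmax can be computed as x_i − max(x) − log(Σ_j e^(x_j − max(x))). Every exponent is then ≤ 0, so exp never overflows, and the result stays in log space, where probabilities that would underflow to 0 are still representable as large negative numbers. A rough sketch of that trick (the stable_log_softmax helper below is hypothetical, not PyTorch's actual implementation):

import torch

def stable_log_softmax(x, dim=-1):
    # log softmax(x)_i = x_i - log(sum_j exp(x_j))
    #                  = (x_i - m) - log(sum_j exp(x_j - m))   with m = max(x)
    # Subtracting the max keeps every exponent <= 0, so exp() cannot overflow.
    m = x.max(dim=dim, keepdim=True).values
    return (x - m) - torch.log(torch.exp(x - m).sum(dim=dim, keepdim=True))

x = torch.tensor([[10.0, 100.0, 1000.0]])
print(stable_log_softmax(x, dim=1))  # tensor([[-990., -900., 0.]])
print(torch.log_softmax(x, dim=1))   # same values from the built-in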
Here is an example comparing the outputs of Softmax and LogSoftmax:
>>> import torch
>>> import torch.nn as nn
>>> m = nn.Softmax(dim=1)
>>> input = torch.randn(2, 3)
>>> output = m(input)
>>> input
tensor([[-1.1723, 0.3103, 1.7434],
[ 0.1054, 0.0876, 1.9890]])
>>> output
tensor([[0.0419, 0.1845, 0.7736],
[0.1168, 0.1148, 0.7684]])
>>> mm = nn.LogSoftmax(dim=1)
>>> output2 = mm(input)
>>> output2
tensor([[-3.1724, -1.6899, -0.2568],
[-2.1470, -2.1649, -0.2634]])
>>> torch.exp(output2)
tensor([[0.0419, 0.1845, 0.7736],
[0.1168, 0.1148, 0.7684]])
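Applying torch.exp to the LogSoftmax output recovers exactly the Softmax probabilities, so nothing is lost by working in log space; the only difference is that the computation happens in a numerically safer way. In practice, LogSoftmax is usually paired with nn.NLLLoss, which expects log-probabilities as input.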