Transformer’s Scaled Dot-Product Attention
Feb 17, 2022
In the Transformer paper ("Attention Is All You Need"), the authors propose Scaled Dot-Product Attention:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
They claim that for large values of d_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients, which is why the dot products are scaled by 1 / √d_k.
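For reference, here is a minimal sketch of that formula in PyTorch (the function name, tensor shapes, and d_k = 64 are my own choices for illustration, not from the post):
import math
import torch

def sdp_attention(q, k, v):
    # q, k, v: (batch, seq_len, d_k); shapes are assumed for illustration
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq_len, seq_len)
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ v

q = k = v = torch.randn(2, 5, 64)
print(sdp_attention(q, k, v).shape)  # torch.Size([2, 5, 64])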
Why does dividing by a value push the softmax output away from the extreme regions? Let's run a quick experiment:
- Original values of
[1, 2, 3, …, 9]
- scaled by 10, we have
[0.1, 0.2, 0.3, …, 0.9]
- with no scaling, we keep
[1, 2, 3, …, 9]
Softmax outputs visualization
Indeed, it works. The reason lies in the exponential function:
e^9 / e^8 ≈ 2.72 is much larger than e^{0.9} / e^{0.8} ≈ 1.11, so after scaling the largest value no longer dominates the softmax output.
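That ratio is easy to check numerically (a tiny sketch, not part of the original post):
import math

print(math.exp(9) / math.exp(8))      # ≈ 2.718: the largest logit dwarfs its neighbour
print(math.exp(0.9) / math.exp(0.8))  # ≈ 1.105: neighbouring logits stay comparable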
Code used for the visualization:
import torch
from torch import nn
from matplotlib import pyplot as plt

f = nn.Softmax(dim=1)
# Row 0: unscaled values [1..9]; row 1: the same values scaled by 10
x = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9],
                  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]], dtype=torch.float)
y = f(x)  # softmax along each row

plt.scatter(range(x.shape[1]), y[0], c="r", label="no scale")
plt.scatter(range(x.shape[1]), y[1], c="g", label="scale by 10")
plt.legend()
plt.show()
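For completeness, the paper's footnote explains why the dot products get large in the first place: if the components of q and k are independent with mean 0 and variance 1, their dot product has mean 0 and variance d_k. A quick sanity check of that statement (my own sketch; the sample size and d_k = 64 are arbitrary):
import torch

d_k = 64
q = torch.randn(100000, d_k)             # components ~ N(0, 1)
k = torch.randn(100000, d_k)
dots = (q * k).sum(dim=1)                # raw dot products
print(dots.var().item())                 # ≈ 64, i.e. ≈ d_k
print((dots / d_k ** 0.5).var().item())  # ≈ 1 after dividing by sqrt(d_k)
With a standard deviation of √d_k ≈ 8, unscaled scores easily differ by 10 or more across a row, which is exactly the saturated regime shown in the experiment above.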