Transformer’s Scaled Dot-Product Attention

Figure: scaled dot-product attention, from [1]
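
The operation in the figure is Attention(Q, K, V) = softmax(QKᵀ / √dₖ) · V: each query is compared to every key by dot product, the scores are scaled down by √dₖ, and the resulting weights average the values. Here is a minimal PyTorch sketch of that equation (the function name and the example shapes are my own, for illustration):

import math
import torch

def scaled_dot_product_attention(q, k, v):
    # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, eq. (1) in [1]
    d_k = q.size(-1)
    # Dot-product similarity between every query and every key,
    # scaled down so large d_k doesn't saturate the softmax
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)  # each query's weights sum to 1 over the keys
    return weights @ v

# Example: 1 batch, 4 tokens, d_k = 64
q = torch.randn(1, 4, 64)
k = torch.randn(1, 4, 64)
v = torch.randn(1, 4, 64)
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 4, 64])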
[Figure: an example score vector [10, 20, 30, …, 90] and its scaled-down counterpart [1, 2, 3, …, 9]]

Code used for the visualization:

import torch
from torch import nn
from matplotlib import pyplot as plt

softmax = nn.Softmax(dim=1)  # softmax over each row
# Row 0: raw scores 1..9; row 1: the same scores scaled down by a factor of 10
x = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9],
                  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]], dtype=torch.float)
y = softmax(x)
plt.scatter(range(x.shape[1]), y[0], c="r", label="no scale")
plt.scatter(range(x.shape[1]), y[1], c="g", label="scale by 10")
plt.legend()
plt.show()
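
In the resulting plot, the unscaled row [1, …, 9] puts almost all of the softmax mass on the largest score, while the scaled-down row [0.1, …, 0.9] stays much flatter. This is the motivation given in [1] for dividing QKᵀ by √dₖ: for large dₖ the dot products grow large in magnitude, pushing the softmax into regions where its gradients are extremely small.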

References

[1] Vaswani, A., et al. "Attention Is All You Need." NeurIPS 2017. https://arxiv.org/abs/1706.03762
