import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Attention layer that attends ``input`` over ``memory`` and
    concatenates the attended context to the input."""

    def __init__(self, input_size):
        super().__init__()
        # Projects each input vector to a scalar attention bias.
        self.input_linear = nn.Linear(input_size, 1, bias=False)
        # Learned per-dimension scale for the dot product,
        # initialised uniformly in [1 / sqrt(input_size), 1).
        self.dot_scale = nn.Parameter(
            torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5))
        )
    def forward(self, input, memory, mask):
        # Additive bias from the input projection: (batch, input_len, 1).
        input_dot = self.input_linear(input)
        # Scaled dot-product scores between input and memory:
        # (batch, input_len, memory_len).
        cross_dot = torch.bmm(
            input * self.dot_scale, memory.permute(0, 2, 1).contiguous()
        )
        att = input_dot + cross_dot
        # Push padded memory positions (mask == 0) to a large negative score.
        att = att - 1e30 * (1 - mask[:, None])
        weight_one = F.softmax(att, dim=-1)
        # Attended memory representation: (batch, input_len, input_size).
        output_one = torch.bmm(weight_one, memory)
        return torch.cat([input, output_one], dim=-1)
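

# Minimal usage sketch (not part of the original module): the batch size,
# sequence lengths, hidden size, and all-ones mask below are illustrative
# assumptions chosen only to show the expected tensor shapes.
if __name__ == "__main__":
    batch, input_len, memory_len, hidden = 2, 5, 7, 16
    layer = SelfAttention(hidden)
    inp = torch.randn(batch, input_len, hidden)
    mem = torch.randn(batch, memory_len, hidden)
    # 1 marks valid memory positions, 0 marks padding.
    mask = torch.ones(batch, memory_len)
    out = layer(inp, mem, mask)
    # The input is concatenated with the attended memory,
    # so the last dimension doubles.
    print(out.shape)  # torch.Size([2, 5, 32])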