Source code for sgnlp.models.lsr.modules.attention

import torch.nn as nn
import torch
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Projects each input position to a scalar attention score.
        self.input_linear = nn.Linear(input_size, 1, bias=False)
        # Learnable per-dimension scaling applied to the input before the dot product.
        self.dot_scale = nn.Parameter(
            torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5))
        )

    def forward(self, input, memory, mask):
        input_dot = self.input_linear(input)
        # Scaled dot-product scores between every input position and every memory position.
        cross_dot = torch.bmm(
            input * self.dot_scale, memory.permute(0, 2, 1).contiguous()
        )
        att = input_dot + cross_dot
        # Push masked-out memory positions towards -inf before the softmax.
        att = att - 1e30 * (1 - mask[:, None])
        weight_one = F.softmax(att, dim=-1)
        output_one = torch.bmm(weight_one, memory)
        # Concatenate the input with its attended memory summary.
        return torch.cat([input, output_one], dim=-1)
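
A minimal usage sketch follows. The tensor shapes and the interpretation of mask (1 = attend, 0 = ignore) are assumptions inferred from the bmm and broadcasting operations above, not taken from the sgnlp documentation.

import torch

# Assumed shapes: input and memory share the same hidden size (input_size),
# and mask covers the memory positions.
batch_size, input_len, mem_len, hidden = 2, 5, 7, 8

attention = SelfAttention(input_size=hidden)

inp = torch.randn(batch_size, input_len, hidden)   # query sequence
memory = torch.randn(batch_size, mem_len, hidden)  # sequence to attend over
mask = torch.ones(batch_size, mem_len)             # assumed: 1 = keep, 0 = mask out

out = attention(inp, memory, mask)
print(out.shape)  # torch.Size([2, 5, 16]): input concatenated with attended memory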