Source code for sgnlp.models.rumour_detection_twitter.modules.layer.attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

__author__ = "Serena Khoo"


class Attention(nn.Module):
    """
    This class defines the dot-product attention
    """

    def __init__(self, config, d_model, n_head):
        super(Attention, self).__init__()

        self.config = config
        self.d_model = d_model
        self.n_head = n_head
    def forward(
        self,
        query,
        key,
        val,
        key_structure=None,
        val_structure=None,
        attention_mask=None,
    ):

        # <--------- Check that the last dim of the key is the same as d_model / n_head -------------->
        d_k = query.shape[-1]
        assert d_k == self.d_model / self.n_head

        # <--------- Calculating the attention, the attention is per head -------------->
        attention_values = torch.matmul(
            query, key.transpose(len(key.shape) - 2, len(key.shape) - 1)
        )  # QK'

        del key
        torch.cuda.empty_cache()

        # <-------- Adding the edge information -------------->
        if key_structure is not None:
            edge_score = torch.matmul(
                query.unsqueeze(3), key_structure.transpose(4, 3)
            ).squeeze(3)
            attention_values = attention_values + edge_score

            del edge_score
            del key_structure
            torch.cuda.empty_cache()

        del query
        torch.cuda.empty_cache()

        # <-------- Scaling of the attention values -------------->
        scaling_factor = np.power(d_k, 0.5)  # d_k**0.5

        # <-------- Apply attention masking if attention_mask is not None -------------->
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(
                len(attention_mask.shape)
            )
            attention_mask = (1.0 - attention_mask) * -100000.0
            attention_values = attention_values + attention_mask

            del attention_mask
            torch.cuda.empty_cache()

        # <-------- Getting the softmax of the attention values -------------->
        attention_values = attention_values / scaling_factor
        attention_values = F.softmax(attention_values, dim=-1)

        # <-------- Getting the final output -------------->
        final_output = torch.matmul(attention_values, val)

        del val
        torch.cuda.empty_cache()

        # <-------- Getting the edge scores to be added to val -------------->
        if val_structure is not None:
            edge_val_score = torch.matmul(
                attention_values.unsqueeze(3), val_structure
            ).squeeze(3)
            final_output = final_output + edge_val_score

            del val_structure
            del edge_val_score
            torch.cuda.empty_cache()

        return final_output, attention_values
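

# ---------------------------------------------------------------------------
# Usage sketch (not part of the library source): a minimal forward pass
# assuming per-head inputs of shape (batch, n_head, seq_len, d_k) with
# d_model == n_head * d_k. The `config` argument is stored but not used by
# forward(), so a placeholder (None) is passed here; the tensor shapes and
# values below are illustrative assumptions, not the library's documented API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, n_head, seq_len, d_k = 2, 4, 8, 16
    d_model = n_head * d_k

    attention = Attention(config=None, d_model=d_model, n_head=n_head)

    query = torch.rand(batch, n_head, seq_len, d_k)
    key = torch.rand(batch, n_head, seq_len, d_k)
    val = torch.rand(batch, n_head, seq_len, d_k)

    # Plain scaled dot-product attention: no structural edge tensors, no mask.
    output, attn = attention(query, key, val)
    print(output.shape)  # (batch, n_head, seq_len, d_k)
    print(attn.shape)    # (batch, n_head, seq_len, seq_len)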