Source code for sgnlp.models.rumour_detection_twitter.modules.transformer.transformer
import torch
import torch.nn as nn
from ..layer import layer
__author__ = "Serena Khoo"

class Transformer(nn.Module):
    @staticmethod
    def init_weights(layer):
        # Xavier-initialise the weights of every linear layer
        if type(layer) == nn.Linear:
            nn.init.xavier_normal_(layer.weight)
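    # Usage note: init_weights is shaped as a callback for nn.Module.apply
    # (e.g. some_module.apply(Transformer.init_weights), which visits every
    # submodule); this file itself never calls it, so how the surrounding
    # sgnlp model applies it is not shown here.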
    def __init__(self, config, n_layers, d_model, n_heads):
        super(Transformer, self).__init__()

        # <----------- Config ----------->
        self.config = config

        # <----------- Model dimensions ----------->
        self.n_layers = n_layers
        self.d_model = d_model
        self.n_heads = n_heads

        # <----------- Stack of Attention layers ----------->
        self.input_stack = nn.ModuleList(
            [layer.Layer(config, d_model, n_heads) for _ in range(n_layers)]
        )
    def forward(
        self,
        query,
        key,
        val,
        key_structure=None,
        val_structure=None,
        attention_mask=None,
    ):
        """
        Takes in a sequence and applies multi-head attention (MHA) to it, layer by layer.
        """
        # The output of each layer is fed back in as the query of the next layer
        self_atten_output = query
        del query
        torch.cuda.empty_cache()

        # Store the attention weights of each layer, keyed by 1-based layer index
        self_atten_weights_dict = {}

        # Pass through the stacked MHA layers
        for i, attn_layer in enumerate(self.input_stack, start=1):
            self_atten_output, self_atten_weights = attn_layer(
                query=self_atten_output,
                key=key,
                val=val,
                key_structure=key_structure,
                val_structure=val_structure,
                attention_mask=attention_mask,
            )
            self_atten_weights_dict[i] = self_atten_weights
            del self_atten_weights
            torch.cuda.empty_cache()

        return self_atten_output, self_atten_weights_dict
    def __repr__(self):
        return str(vars(self))
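
# A minimal usage sketch (assumptions: `config` is the rumour-detection model
# config object consumed by layer.Layer, each Layer returns an
# (output, attention_weights) pair as the forward loop above relies on, and
# the tensor shapes below are illustrative only):
#
#     transformer = Transformer(config, n_layers=2, d_model=64, n_heads=4)
#     transformer.apply(Transformer.init_weights)
#     x = torch.randn(8, 30, 64)  # (batch, seq_len, d_model)
#     out, attn_weights_per_layer = transformer(query=x, key=x, val=x)
#     # out: transformed sequence; attn_weights_per_layer: dict keyed 1..n_layers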