# Source code for sgnlp.models.rumour_detection_twitter.modules.layer.layer

import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import numpy as np

from . import feedforward_network
from . import multi_head_attention

# Module authorship metadata.
__author__ = "Serena Khoo"


class Layer(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    The layer owns a ``MultiHeadAttention`` module and a
    ``FeedForwardNetwork`` module; ``forward`` runs them in that order.

    Args:
        config: Configuration object handed unchanged to both sub-modules.
        d_model: Forwarded to ``MultiHeadAttention`` (presumably the model
            hidden size -- confirm against that module).
        n_head: Forwarded to ``MultiHeadAttention`` (presumably the number
            of attention heads -- confirm against that module).
    """

    def __init__(self, config, d_model, n_head):
        super().__init__()

        # Keep the construction arguments available on the instance.
        self.config = config
        self.d_model = d_model
        self.n_head = n_head

        # Sub-modules, applied in this order by ``forward``.
        self.attn_network = multi_head_attention.MultiHeadAttention(
            config, d_model, n_head
        )
        self.ffn = feedforward_network.FeedForwardNetwork(config)

    def forward(
        self,
        query,
        key,
        val,
        key_structure=None,
        val_structure=None,
        attention_mask=None,
    ):
        """Run attention, then the feed-forward network.

        Args:
            query: Passed positionally to the attention module.
            key: Passed positionally to the attention module.
            val: Passed positionally to the attention module.
            key_structure: Optional; forwarded by keyword to the attention
                module.
            val_structure: Optional; forwarded by keyword to the attention
                module.
            attention_mask: Optional; forwarded by keyword to the attention
                module.

        Returns:
            A tuple ``(enc_output, atten_values)``: the feed-forward output
            computed from the first value returned by the attention module,
            and the second value returned by the attention module, passed
            through unchanged.
        """
        self_atten_features, atten_values = self.attn_network(
            query,
            key,
            val,
            key_structure=key_structure,
            val_structure=val_structure,
            attention_mask=attention_mask,
        )

        enc_output = self.ffn(self_atten_features)

        # Drop the local reference to the intermediate features and ask the
        # CUDA caching allocator to release its unoccupied cached memory.
        del self_atten_features
        torch.cuda.empty_cache()

        return enc_output, atten_values