Source code for sgnlp.models.rumour_detection_twitter.modules.transformer.hierarchical_transformer

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..submodules import word_module
from ..submodules import post_module

__author__ = "Serena Khoo"


class HierarchicalTransformer(nn.Module):
    @staticmethod
    def init_weights(layer):
        if type(layer) == nn.Linear:
            nn.init.xavier_normal_(layer.weight)

    def __init__(self, config):
        super(HierarchicalTransformer, self).__init__()

        # <----------- Config ----------->
        self.config = config

        # <----------- Both word and post modules ----------->
        self.word_module = word_module.WordModule(config)
        self.post_module = post_module.PostModule(config)
    def forward(
        self,
        X,
        word_pos,
        time_delay,
        structure,
        attention_mask_word=None,
        attention_mask_post=None,
        return_attention=False,
    ):
        # <----------- Getting the dimensions ----------->
        batch_size, num_posts, num_words, emb_dim = X.shape

        # <----------- Passing through word module ----------->
        X_word, self_atten_weights_dict_word = self.word_module(
            X, word_pos, attention_mask=attention_mask_word
        )

        # <----------- Passing through post module ----------->
        output, self_atten_output_post, self_atten_weights_dict_post = self.post_module(
            X_word,
            time_delay,
            batch_size,
            num_posts,
            emb_dim,
            structure=structure,
            attention_mask=attention_mask_post,
        )

        # <--------- Clear the memory -------------->
        torch.cuda.empty_cache()

        if return_attention:
            return (
                output,
                self_atten_output_post,
                self_atten_weights_dict_word,
                self_atten_weights_dict_post,
            )

        # <-------- Delete the attention weights if not returning them ---------->
        del self_atten_weights_dict_word
        del self_atten_weights_dict_post
        del self_atten_output_post
        torch.cuda.empty_cache()

        return output
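
A minimal usage sketch follows, showing the tensor shapes that forward expects based on the unpacking of X.shape above. It is not part of the library source: the dimension values, the shapes chosen for word_pos, time_delay, and structure, and the commented-out model construction are illustrative assumptions, since the exact config fields required by WordModule and PostModule are defined elsewhere in the package.

# Usage sketch (assumptions noted above; not the documented sgnlp API).
import torch

batch_size, num_posts, num_words, emb_dim = 2, 5, 30, 300

# Token embeddings for each word of each post: (batch, posts, words, emb_dim),
# matching the unpacking "batch_size, num_posts, num_words, emb_dim = X.shape".
X = torch.randn(batch_size, num_posts, num_words, emb_dim)

# Assumed per-word position indices: (batch, posts, words).
word_pos = torch.arange(num_words).expand(batch_size, num_posts, num_words)

# Assumed per-post time-delay buckets and post-to-post structure indices.
time_delay = torch.zeros(batch_size, num_posts, dtype=torch.long)
structure = torch.zeros(batch_size, num_posts, num_posts, dtype=torch.long)

# Model construction is omitted because it needs a fully populated config:
# model = HierarchicalTransformer(config)
# output = model(X, word_pos, time_delay, structure)
# output, post_atten, word_weights, post_weights = model(
#     X, word_pos, time_delay, structure, return_attention=True
# )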