import torch
import torch.nn as nn
import torch.nn.functional as F
from ..transformer import transformer
from ..encoder import learned_position_encoder
__author__ = "Serena Khoo"
class PostModule(nn.Module):
@staticmethod
def init_weights(layer):
if type(layer) == nn.Linear:
nn.init.xavier_normal_(layer.weight)
def __init__(self, config):
super(PostModule, self).__init__()
# <----------- Config ----------->
self.config = config
# <----------- Key and val structure encoder ----------->
if self.config.include_key_structure:
self.key_structure_encoder = (
learned_position_encoder.LearnedPositionEncoder(
self.config, self.config.n_head
)
)
if self.config.include_val_structure:
self.val_structure_encoder = (
learned_position_encoder.LearnedPositionEncoder(
self.config, self.config.n_head
)
)
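        # (The LearnedPositionEncoder instances above turn the `structure` tensor passed to forward()
        #  into key/val structure representations that are handed to the post-level transformer below.)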
        # <----------- Post-level transformer over the sequence of post embeddings ----------->
self.transformer_post = transformer.Transformer(
self.config,
self.config.n_mha_layers,
self.config.d_model,
self.config.n_head,
)
# <----------- Embedding the posts through n FC layers (post level) ----------->
if self.config.ff_post:
self.emb_layer_post = nn.ModuleList(
[
nn.Linear(self.config.emb_dim, self.config.emb_dim)
for _ in range(self.config.num_emb_layers)
]
)
        # <----------- Fine-tuning layer applied to the first (source) post's embedding ----------->
self.fine_tune_layer = nn.Linear(self.config.emb_dim, self.config.emb_dim)
        # <----------- Final layer to map the embedding to the output classes (num_classes) ----------->
self.final_layer_emb = nn.Sequential(
nn.Linear(self.config.emb_dim, self.config.num_classes),
nn.LogSoftmax(dim=1),
)
# <----------- To map each post vector to a scalar ----------->
self.condense_layer_post = nn.Linear(self.config.d_model, 1)
self.final_layer_posts = nn.Sequential(
nn.Linear(self.config.max_tweets, self.config.num_classes),
nn.LogSoftmax(dim=1),
)
        # <----------- Dropout for regularization ----------->
self.dropout = nn.Dropout(p=self.config.dropout_rate, inplace=True)
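        # Note: inplace=True means the self.dropout(...) calls in forward() modify their argument
        # directly, which is why their return value is never assigned.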
# <----------- Initialization of weights ----------->
if self.config.ff_post:
self.emb_layer_post.apply(PostModule.init_weights)
self.fine_tune_layer.apply(PostModule.init_weights)
self.final_layer_emb.apply(PostModule.init_weights)
self.condense_layer_post.apply(PostModule.init_weights)
self.final_layer_posts.apply(PostModule.init_weights)
    def forward(
self,
X_word,
time_delay,
batch_size,
num_posts,
emb_dim,
structure=None,
attention_mask=None,
):
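        # Assumed shapes (inferred from how the tensors are used below, not stated in the original code):
        #   X_word:         (batch_size, num_posts, emb_dim) post embeddings
        #   time_delay:     broadcastable to X_word, added when include_time_interval is set
        #   structure:      indices consumed by the LearnedPositionEncoder instances
        #   attention_mask: (batch_size, num_posts) with 1 for real posts and 0 for padding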
# <----------- Encoding the structure ----------->
key_structure = None
val_structure = None
if self.config.include_key_structure:
key_structure = self.key_structure_encoder(structure)
if self.config.include_val_structure:
val_structure = self.val_structure_encoder(structure)
        # <----------- Passing X_word through n FC layers (post level); toggled on/off in config ----------->
if self.config.ff_post:
for i in range(self.config.num_emb_layers):
X_word = self.emb_layer_post[i](X_word)
# <----------- Adding in time delay information ----------->
if self.config.include_time_interval:
X_word += time_delay
# <----------- Setting the query, key and val ----------->
query_post = X_word
key_post = X_word
val_post = X_word
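        # query, key and val all reference the same tensor (self-attention over the posts); because the
        # dropout below is in-place, it is effectively applied to that shared tensor three times.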
del X_word
torch.cuda.empty_cache()
        # <----------- Applying dropout for regularization (in-place, see inplace=True above) ----------->
self.dropout(query_post)
self.dropout(key_post)
self.dropout(val_post)
        # <----------- Passing through the post-level transformer (attention weights are kept and returned) ----------->
self_atten_output_post, self_atten_weights_dict_post = self.transformer_post(
query_post,
key_post,
val_post,
key_structure=key_structure,
val_structure=val_structure,
attention_mask=attention_mask,
)
        # <----------- Per-post attention weights are only produced by approach 3 below ----------->
        posts_attention_weights = None
        # <----------- Baseline: Getting the average embedding of the self-attended output ----------->
if self.config.post_module_version == 0:
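            # Mean-pool over the post dimension: (batch, num_posts, d_model) -> (batch, d_model)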
self_atten_output_post = F.adaptive_avg_pool1d(
self_atten_output_post.permute(0, 2, 1), 1
).squeeze(-1)
# # <----------- Doing dropout here ----------->
# self.dropout(self_atten_output_post)
# <----------- Getting the predictions ----------->
output = self.final_layer_emb(self_atten_output_post)
torch.cuda.empty_cache()
        # <----------- Approach 1: Condensing the post features into a vector of length max_tweets ----------->
if self.config.post_module_version == 1:
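            # Map each post vector to a scalar score: (batch, num_posts, d_model) -> (batch, num_posts),
            # then classify from the resulting max_tweets-length vector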
self_atten_output_post = self.condense_layer_post(
self_atten_output_post
).squeeze(-1)
# # <----------- Doing dropout here ----------->
# self.dropout(self_atten_output_post)
# <----------- Getting the predictions ----------->
output = self.final_layer_posts(self_atten_output_post)
torch.cuda.empty_cache()
# <----------- Approach 2: Just get the first vector (Vector of the source post) ----------->
if self.config.post_module_version == 2:
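            # Keep only the first (source) post's vector and pass it through the fine-tuning layer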
self_atten_output_post = self.fine_tune_layer(
self_atten_output_post[:, 0, :]
)
# # <----------- Doing dropout here ----------->
# self.dropout(self_atten_output_post)
# <----------- Getting the predictions ----------->
output = self.final_layer_emb(self_atten_output_post)
torch.cuda.empty_cache()
        # <----------- Approach 3: Attention over the post vectors ----------->
if self.config.post_module_version == 3:
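            # Turn the 1/0 padding mask into an additive mask: (mask - 1) * 1e5 gives 0 for real posts
            # and -1e5 for padded ones (note: this modifies the passed-in attention_mask in place)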
attention_mask += -1.0
attention_mask *= 100000.0
posts_attention_values = self.condense_layer_post(self_atten_output_post)
posts_attention_weights = F.softmax(
posts_attention_values.permute(0, 2, 1) + attention_mask.unsqueeze(-2),
dim=-1,
)
del posts_attention_values
torch.cuda.empty_cache()
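            # Weighted sum of post vectors: (batch, 1, num_posts) x (batch, num_posts, d_model) -> (batch, d_model)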
self_atten_output_post = torch.matmul(
posts_attention_weights, self_atten_output_post
).squeeze(1)
# # <----------- Doing dropout here ----------->
# self.dropout(self_atten_output_post)
# <----------- Getting the predictions ----------->
output = self.final_layer_emb(self_atten_output_post)
del attention_mask
torch.cuda.empty_cache()
return output, posts_attention_weights, self_atten_weights_dict_post
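
# Usage sketch (inferred from the forward() signature; names and shapes outside this file are assumptions):
#   module = PostModule(config)
#   output, posts_attention_weights, self_atten_weights_dict_post = module(
#       X_word, time_delay, batch_size, num_posts, emb_dim,
#       structure=structure, attention_mask=attention_mask,
#   )
# output holds log-probabilities over config.num_classes; posts_attention_weights is None unless
# config.post_module_version == 3.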