# Source code for sgnlp.models.rumour_detection_twitter.modules.encoder.position_encoder

import numpy as np
import torch
import torch.nn as nn


class PositionEncoder(nn.Module):
    """Map integer position indices to sinusoidal position vectors.

    A table with ``max_index + 1`` rows and ``config.emb_dim`` columns is
    built once at construction time and wrapped in an ``nn.Embedding``.
    ``forward`` is a plain row lookup into that table; no other features
    (e.g. word embeddings) are combined here.

    The table entry for position ``pos`` and column ``dim`` is::

        angle = pos / 10000 ** (2 * (dim // 2) / d_model)
        value = sin(angle) if dim is even else cos(angle)

    Attributes set by ``__init__``: ``config``, ``max_index``, ``d_model``,
    ``d_emb_dim``, ``requires_grad``, ``freeze``, ``pos_embd_weights`` and
    ``position_encoding``.
    """

    @staticmethod
    def pos_emb(pos, dim, d_model):
        """Return the raw angle (before sin/cos) for one position and column.

        Args:
            pos: Position index.
            dim: Column index inside the embedding vector.
            d_model: Divisor used in the exponent of the frequency term.

        Returns:
            ``pos / 10000 ** (2 * (dim // 2) / d_model)``.
        """
        return pos / np.power(10000, (2 * (dim // 2) / d_model))

    @staticmethod
    def cal_pos_emb(pos, d_emb_dim, d_model):
        """Return the raw angles of one position for every column.

        Args:
            pos: Position index.
            d_emb_dim: Number of columns to compute.
            d_model: Divisor used in the exponent of the frequency term.

        Returns:
            A list of ``d_emb_dim`` angles, one per column.
        """
        return [
            PositionEncoder.pos_emb(pos, dim, d_model) for dim in range(d_emb_dim)
        ]

    @staticmethod
    def get_position_embedding(max_index, d_model, d_emb_dim, requires_grad=False):
        """Build the sinusoidal lookup table for positions ``0..max_index``.

        Args:
            max_index: Largest position index that must be representable
                (inclusive), so the table has ``max_index + 1`` rows.
            d_model: Divisor used in the exponent of the frequency term.
            d_emb_dim: Number of columns of the table.
            requires_grad: Whether the returned tensor should track
                gradients. Any truthy/falsy value is accepted.

        Returns:
            A ``torch.float32`` tensor of shape ``(max_index + 1, d_emb_dim)``
            with sine values in even columns and cosine values in odd columns.

        Raises:
            ValueError: If ``max_index`` is negative.
        """
        if max_index < 0:
            raise ValueError(f"max_index must be non-negative, got {max_index}")

        # The denominator depends only on the column, so compute it once per
        # column instead of once per (position, column) pair.
        exponents = np.array(
            [2 * (dim // 2) / d_model for dim in range(d_emb_dim)],
            dtype=np.float64,
        )
        denominators = np.power(10000, exponents)

        # Broadcast (rows, 1) / (cols,) -> (rows, cols) table of raw angles.
        positions = np.arange(max_index + 1, dtype=np.float64)
        table = positions[:, np.newaxis] / denominators

        table[:, 0::2] = np.sin(table[:, 0::2])  # even columns: 0, 2, 4, ...
        table[:, 1::2] = np.cos(table[:, 1::2])  # odd columns: 1, 3, 5, ...

        position_embedding = torch.tensor(table, dtype=torch.float32)
        # ``Tensor.requires_grad`` only accepts real bools, so coerce here.
        position_embedding.requires_grad = bool(requires_grad)
        return position_embedding

    def __init__(self, config, max_index):
        """Create the lookup table and the embedding layer that serves it.

        Args:
            config: Object exposing ``d_model``, ``emb_dim`` and
                ``train_pos_emb`` attributes.
            max_index: Largest position index to support (inclusive).
        """
        super().__init__()
        self.config = config
        self.max_index = max_index  # Including the index
        self.d_model = self.config.d_model
        self.d_emb_dim = self.config.emb_dim
        # Coerce so that 0/1-style config flags behave like False/True.
        self.requires_grad = bool(self.config.train_pos_emb)
        self.freeze = not self.requires_grad

        # <------------- Defining the position embedding ------------->
        self.pos_embd_weights = PositionEncoder.get_position_embedding(
            max_index=self.max_index,
            d_model=self.config.d_model,
            d_emb_dim=self.d_emb_dim,
            requires_grad=self.requires_grad,
        )
        self.position_encoding = nn.Embedding.from_pretrained(
            self.pos_embd_weights, freeze=self.freeze
        )

    def forward(self, src_seq):
        """Look up the position vector for every index in ``src_seq``.

        Args:
            src_seq: Tensor of position indices. It is cast to ``long``
                before the lookup, so values must lie in ``0..max_index``.
                NOTE(review): the original docs say these indices are
                derived from time stamps - confirm against the caller.

        Returns:
            encoded_pos_features: Tensor of shape
            ``src_seq.shape + (emb_dim,)`` holding the looked-up rows.
        """
        position_index = src_seq.long()
        encoded_pos_features = self.position_encoding(position_index)
        return encoded_pos_features