import numpy as np
import torch
import torch.nn as nn


class WordEncoder(nn.Module):
    """Encodes the input tokens into vectors.

    Two pieces of information go into the encoded representation:
    1. Word embedding
    2. Position embedding
    """

    def __init__(self, config):
        # modified the original script to remove the need for the data loader
        super(WordEncoder, self).__init__()
        self.config = config
        self.vocab_size = config.max_vocab
        # <------------- Defining the word embedding dimensions ------------->
        self.embedding_dim = self.config.emb_dim
        # <------------- Loading the pretrained embedding weights ------------->
        self.emb = nn.Embedding(self.vocab_size, self.embedding_dim)
        # freeze the embedding weights unless config.train_word_emb is True
        self.emb.weight.requires_grad = self.config.train_word_emb

    def load_pretrained_embedding(self, pretrained_embedding_path):
        # copy a pretrained embedding matrix (saved as a .npy file) into the layer
        pretrained_embedding = torch.from_numpy(np.load(pretrained_embedding_path))
        self.emb.weight.data.copy_(pretrained_embedding)

    def forward(self, token_ids):
        """Encodes the input using the word embedding.

        Args:
            token_ids: LongTensor of shape (batch_size, num_posts, num_words)

        Returns:
            encoded_we_features: Tensor of shape (batch_size, num_posts, num_words, emb_dim)
        """
        encoded_we_features = self.emb(token_ids)
        return encoded_we_features
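

# A minimal usage sketch (not part of the original module). It assumes a config
# object exposing max_vocab, emb_dim, and train_word_emb (the attributes read
# above) and, optionally, a pretrained embedding matrix saved as a .npy file of
# shape (max_vocab, emb_dim); the file path below is hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(max_vocab=10000, emb_dim=300, train_word_emb=False)
    encoder = WordEncoder(config)
    # encoder.load_pretrained_embedding("pretrained_embeddings.npy")  # hypothetical path

    # token_ids: (batch_size, num_posts, num_words)
    token_ids = torch.randint(0, config.max_vocab, (2, 4, 16))
    features = encoder(token_ids)
    print(features.shape)  # torch.Size([2, 4, 16, 300])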