# Source code for sgnlp.models.rst_pointer.modules.encoder_rnn

from typing import List, Tuple

import torch
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding


class EncoderRNN(nn.Module):
    """
    EncoderRNN model to be used in the encoder of the RST Parser network.

    A stack of bidirectional GRU layers over word embeddings, preceded by
    batch normalization (per-batch statistics only) and dropout. The two GRU
    directions are summed — not concatenated — so outputs and hidden states
    stay ``hidden_size``-dimensional.
    """

    def __init__(self, word_dim, hidden_size, rnn_layers=6, dropout=0.2):
        """
        Args:
            word_dim (int): dimensionality of the input word embeddings.
            hidden_size (int): hidden size of each GRU direction.
            rnn_layers (int): number of stacked GRU layers.
            dropout (float): dropout probability applied to the input
                embeddings, the GRU outputs, and between GRU layers
                (the inter-layer dropout is disabled when rnn_layers == 1).
        """
        super().__init__()
        self.word_dim = word_dim
        self.hidden_size = hidden_size
        self.rnn_layers = rnn_layers
        self.dropout_layer = nn.Dropout(dropout)
        # Normalize over the embedding (channel) dimension using batch
        # statistics only: no learned affine parameters, no running stats.
        self.batchnorm_input = nn.BatchNorm1d(
            word_dim, affine=False, track_running_stats=False
        )
        self.gru = nn.GRU(
            word_dim,
            hidden_size,
            rnn_layers,
            batch_first=True,
            # PyTorch warns if dropout is set on a single-layer RNN.
            dropout=(0 if rnn_layers == 1 else dropout),
            bidirectional=True,
        )

    def forward(
        self, input_embeddings: "BatchEncoding", input_lengths: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for encoder RNN.

        Args:
            input_embeddings (BatchEncoding): input embeddings from the RST
                preprocessor; must expose ``["elmo_representations"][0]`` as a
                (batch, seq_len, word_dim) tensor.
            input_lengths (List[int]): true (unpadded) length of each sequence.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``outputs`` of shape
            (batch, max_len, hidden_size) with the two GRU directions summed,
            and ``hidden`` of shape (rnn_layers, batch, hidden_size) with the
            directional final states summed.
        """
        embeddings = input_embeddings["elmo_representations"][0]

        # BatchNorm1d normalizes the channel axis, which it expects second, so
        # permute (batch, seq, dim) -> (batch, dim, seq) around the norm.
        embeddings = embeddings.permute(0, 2, 1)
        embeddings = self.batchnorm_input(embeddings)
        embeddings = embeddings.permute(0, 2, 1)

        # apply dropout
        embeddings = self.dropout_layer(embeddings)

        # enforce_sorted=False because input_lengths are not sorted;
        # enforce_sorted=True is only required for ONNX export. Reference:
        # https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html
        packed = nn.utils.rnn.pack_padded_sequence(
            embeddings, input_lengths, batch_first=True, enforce_sorted=False
        )

        # initialize hidden state
        batch_size = embeddings.size(0)
        hidden_initial = self.init_hidden(batch_size)

        # feed-forward through GRU
        outputs, hidden = self.gru(packed, hidden_initial)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        # apply dropout
        outputs = outputs.contiguous()
        outputs = self.dropout_layer(outputs)

        # sum bidirectional GRU outputs (forward + backward halves of the
        # feature dimension)
        outputs = outputs[:, :, : self.hidden_size] + outputs[:, :, self.hidden_size :]

        # sum the forward/backward final hidden states per layer: the first
        # rnn_layers entries are the forward directions, the rest backward.
        hidden = hidden.contiguous()
        hidden = hidden[: self.rnn_layers, :, :] + hidden[self.rnn_layers :, :, :]

        return outputs, hidden

    def init_hidden(self, batch_size):
        """Return a zero initial hidden state of shape
        (2 * rnn_layers, batch_size, hidden_size) on the module's device."""
        # Use any registered parameter to find the device the module lives on;
        # this is more robust than poking at GRU weight internals.
        device = next(self.parameters()).device
        return torch.zeros(
            2 * self.rnn_layers, batch_size, self.hidden_size, device=device
        )