Source code for sgnlp.models.sentic_gcn.modules.dynamic_rnn

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class DynamicLSTM(nn.Module):
    """
    A dynamic LSTM class which can handle variable length sequences.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = True,
        dropout: float = 0,
        bidirectional: bool = False,
        only_use_last_hidden_state: bool = False,
        rnn_type: str = "LSTM",
    ) -> None:
        super(DynamicLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.only_use_last_hidden_state = only_use_last_hidden_state
        self.rnn_type = rnn_type
        self.__init_rnn()

    def __init_rnn(self) -> None:
        """
        Helper method to initialize the configured RNN type.
        """
        input_args = {
            "input_size": self.input_size,
            "hidden_size": self.hidden_size,
            "num_layers": self.num_layers,
            "bias": self.bias,
            "batch_first": self.batch_first,
            "dropout": self.dropout,
            "bidirectional": self.bidirectional,
        }
        if self.rnn_type == "LSTM":
            self.rnn = nn.LSTM(**input_args)
        elif self.rnn_type == "GRU":
            self.rnn = nn.GRU(**input_args)
        elif self.rnn_type == "RNN":
            self.rnn = nn.RNN(**input_args)
        else:
            # Fail fast instead of leaving self.rnn undefined and raising an
            # AttributeError later in forward().
            raise ValueError(f"Unsupported rnn_type: {self.rnn_type}")
    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor, h0: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
        # Sort the batch by descending sequence length, as required by
        # pack_padded_sequence with its default enforce_sorted=True.
        x_sort_idx = torch.argsort(-x_len)
        x_unsort_idx = torch.argsort(x_sort_idx).long()
        x_len = x_len[x_sort_idx]
        x = x[x_sort_idx.long()]

        # Pack the padded batch so the RNN skips padding timesteps.
        x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len.cpu(), batch_first=self.batch_first)

        if self.rnn_type == "LSTM":
            # Note: h0 is reused as both the initial hidden and cell state.
            out_pack, (ht, ct) = self.rnn(x_emb_p, None) if h0 is None else self.rnn(x_emb_p, (h0, h0))
        else:
            out_pack, ht = self.rnn(x_emb_p, None) if h0 is None else self.rnn(x_emb_p, h0)
            ct = None

        # Unsort the final hidden state back to the original batch order.
        # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
        ht = torch.transpose(ht, 0, 1)[x_unsort_idx]
        ht = torch.transpose(ht, 0, 1)

        if self.only_use_last_hidden_state:
            return ht

        # Unpack and unsort the padded output sequence.
        out = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=self.batch_first)  # (sequence, lengths)
        out = out[0]
        out = out[x_unsort_idx]

        # Unsort the final cell state for LSTM.
        if self.rnn_type == "LSTM":
            # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
            ct = torch.transpose(ct, 0, 1)[x_unsort_idx]
            ct = torch.transpose(ct, 0, 1)
        return out, (ht, ct)
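
For reference, a minimal usage sketch (not part of the library source; the tensor sizes, values, and lengths below are illustrative assumptions). It shows the expected inputs, a padded batch plus a tensor of true sequence lengths, which the forward pass sorts and unsorts internally, so the lengths need not be pre-sorted.

# Hypothetical usage example; shapes and values are assumptions for illustration.
lstm = DynamicLSTM(input_size=300, hidden_size=128, batch_first=True)

x = torch.randn(3, 5, 300)       # padded batch: (batch, max_seq_len, input_size)
x_len = torch.tensor([3, 5, 2])  # true lengths per sequence; any order is fine

out, (ht, ct) = lstm(x, x_len)
print(out.shape)  # torch.Size([3, 5, 128]): padded outputs in original batch order
print(ht.shape)   # torch.Size([1, 3, 128]): (num_layers * num_directions, batch, hidden_size)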