Source code for sgnlp.models.lsr.modules.encoder

import torch.nn as nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_embedding, dropout_encoder):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        # Bidirectional GRU over the (already embedded) input sequence
        self.encoder = nn.GRU(input_size, self.hidden_size, bidirectional=True)
        self.dropout_embedding = nn.Dropout(p=dropout_embedding)
        self.dropout_encoder = nn.Dropout(p=dropout_encoder)
    def forward(self, seq, lens):
        batch_size = seq.shape[0]
        # Sort sequences by length (descending) for packing; keep the inverse
        # permutation so the original batch order can be restored afterwards.
        lens_sorted, lens_argsort = torch.sort(lens, 0, True)
        _, lens_argsort_argsort = torch.sort(lens_argsort, 0)
        seq_ = torch.index_select(seq, 0, lens_argsort)
        seq_embd = self.dropout_embedding(seq_)
        # pack_padded_sequence requires the lengths to be a CPU tensor
        packed = pack_padded_sequence(seq_embd, lens_sorted.cpu(), batch_first=True)
        self.encoder.flatten_parameters()
        output, h = self.encoder(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        output = output.contiguous()
        # Restore the original batch order; shape is B x m x 2 * hidden_size
        output = torch.index_select(output, 0, lens_argsort_argsort)
        # Last hidden state: concatenate the two directions into one vector per example
        h = h.permute(1, 0, 2).contiguous().view(batch_size, 1, -1)
        h = torch.index_select(h, 0, lens_argsort_argsort)
        output = self.dropout_encoder(output)
        h = self.dropout_encoder(h)
        return output, h
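
A minimal usage sketch, not part of the library source: the sizes below are illustrative assumptions, and `seq` is assumed to already be an embedded, padded batch of shape (batch, max_len, input_size).

    # Hypothetical example: dimensions and dropout rates are arbitrary choices.
    import torch

    encoder = Encoder(input_size=100, hidden_size=128,
                      dropout_embedding=0.2, dropout_encoder=0.2)
    seq = torch.randn(4, 10, 100)        # 4 padded sequences, max length 10
    lens = torch.tensor([10, 7, 5, 3])   # true lengths; forward() sorts internally
    output, h = encoder(seq, lens)
    # output: (4, 10, 256) per-token states from both GRU directions
    # h:      (4, 1, 256)  concatenated final hidden states per sequence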