Source code for sgnlp.models.rst_pointer.modeling

from dataclasses import dataclass
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as R
from torch.autograd import Variable
from transformers import PreTrainedModel
from transformers.file_utils import ModelOutput

from .config import RstPointerSegmenterConfig, RstPointerParserConfig
from .modules.classifier import LabelClassifier
from .modules.decoder_rnn import DecoderRNN
from .modules.elmo import initialize_elmo
from .modules.encoder_rnn import EncoderRNN
from .modules.pointer_attention import PointerAtten
from .modules.type import DiscourseTreeNode, DiscourseTreeSplit
from .utils import get_relation_and_nucleus


@dataclass
class RstPointerSegmenterModelOutput(ModelOutput):
    loss: float = None
    start_boundaries: np.array = None
    end_boundaries: np.array = None

class RstPointerSegmenterPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """

    config_class = RstPointerSegmenterConfig
    base_model_prefix = "rst_pointer_segmenter"

    def _init_weights(self, module):
        pass

class RstPointerSegmenterModel(RstPointerSegmenterPreTrainedModel):
    """This model performs discourse segmentation.

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters
    related to general usage and behavior.

    Args:
        config (:class:`~sgnlp.models.rst_pointer.RstPointerSegmenterConfig`):
            Model configuration class with all the parameters of the model. Initializing with a config file
            does not load the weights associated with the model, only the configuration.
            Use the :obj:`.from_pretrained` method to load the model weights.

    Example::

        from sgnlp.models.rst_pointer import RstPointerSegmenterConfig, RstPointerSegmenterModel

        # Method 1: Loading a default model
        segmenter_config = RstPointerSegmenterConfig()
        segmenter = RstPointerSegmenterModel(segmenter_config)

        # Method 2: Loading from pretrained
        segmenter_config = RstPointerSegmenterConfig.from_pretrained(
            'https://storage.googleapis.com/sgnlp/models/rst_pointer/segmenter/config.json')
        segmenter = RstPointerSegmenterModel.from_pretrained(
            'https://storage.googleapis.com/sgnlp/models/rst_pointer/segmenter/pytorch_model.bin',
            config=segmenter_config)
    """

    def __init__(self, config: RstPointerSegmenterConfig):
        super().__init__(config)

        self.word_dim = config.word_dim
        self.hidden_dim = config.hidden_dim
        self.dropout_prob = config.dropout_prob
        self.use_bilstm = config.use_bilstm
        self.num_rnn_layers = config.num_rnn_layers
        self.rnn_type = config.rnn_type
        self.is_batch_norm = config.is_batch_norm

        self.dropout = nn.Dropout(config.dropout_prob)
        self.embedding, self.word_dim = initialize_elmo(config.elmo_size)

        if self.rnn_type in ["LSTM", "GRU"]:
            self.decoder_rnn = getattr(nn, self.rnn_type)(
                input_size=2 * self.hidden_dim if self.use_bilstm else self.hidden_dim,
                hidden_size=2 * self.hidden_dim if self.use_bilstm else self.hidden_dim,
                num_layers=self.num_rnn_layers,
                dropout=self.dropout_prob,
                batch_first=True,
            )
            self.encoder_rnn = getattr(nn, self.rnn_type)(
                input_size=self.word_dim,
                hidden_size=self.hidden_dim,
                num_layers=self.num_rnn_layers,
                bidirectional=self.use_bilstm,
                dropout=self.dropout_prob,
                batch_first=True,
            )
        else:
            raise ValueError("rnn_type should be one of ['LSTM', 'GRU'].")

        self.num_encoder_bi = 2 if self.use_bilstm else 1

    def init_hidden(self, hsize, batchsize):
        h_0 = Variable(
            torch.zeros(self.num_encoder_bi * self.num_rnn_layers, batchsize, hsize)
        ).to(self.device)
        if self.rnn_type == "LSTM":
            c_0 = Variable(
                torch.zeros(self.num_encoder_bi * self.num_rnn_layers, batchsize, hsize)
            ).to(self.device)
            return h_0, c_0
        return h_0

    def _run_rnn_packed(self, cell, x, x_lens, h=None):
        # Sort first if ONNX exportability is needed
        x_packed = R.pack_padded_sequence(
            x, x_lens, batch_first=True, enforce_sorted=False
        )
        if h is not None:
            output, h = cell(x_packed, h)
        else:
            output, h = cell(x_packed)
        output, _ = R.pad_packed_sequence(output, batch_first=True)
        return output, h

    def pointer_encoder(self, sentences_ids, sentences_lens):
        # Batch norm is constructed per call; with track_running_stats=False it
        # normalizes using the statistics of the current batch only.
        batch_norm = nn.BatchNorm1d(
            self.word_dim, affine=False, track_running_stats=False
        )
        batch_size = len(sentences_ids)

        embeddings = self.embedding(sentences_ids)
        batch_x_elmo = embeddings["elmo_representations"][0]  # ELMo output: [batch, length, d_elmo]

        x = batch_x_elmo
        if self.is_batch_norm:
            x = x.permute(0, 2, 1)  # N C L
            x = batch_norm(x)
            x = x.permute(0, 2, 1)  # N L C

        x = self.dropout(x)

        encoder_lstm_co_h_o = self.init_hidden(self.hidden_dim, batch_size)
        output_encoder, hidden_states_encoder = self._run_rnn_packed(
            self.encoder_rnn, x, sentences_lens, encoder_lstm_co_h_o
        )  # batch_first=True
        output_encoder = output_encoder.contiguous()
        output_encoder = self.dropout(output_encoder)

        return output_encoder, hidden_states_encoder

    def pointer_layer(self, encoder_states, cur_decoder_state):
        # We use simple dot-product attention to compute the pointer distribution
        attention_pointer = torch.matmul(encoder_states, cur_decoder_state).unsqueeze(1)
        attention_pointer = attention_pointer.permute(1, 0)

        att_weights = F.softmax(attention_pointer, dim=1)
        logits = F.log_softmax(attention_pointer, dim=1)

        return logits, att_weights

    def decoder(self, h_n, h_end, batch_x_lens, batch_y=None):
        """
        Args:
            h_n: all hidden states
            h_end: final hidden state
            batch_x_lens: lengths of x (i.e. number of tokens)
            batch_y: optional. Provide to compute the loss metric.

        Returns:
            A tuple containing the following values:
                batch_start_boundaries: array of start tokens for each predicted EDU
                batch_end_boundaries: array of end tokens for each predicted EDU
                batch_align_matrix: pointer attention weights for each decoding step
                batch_loss: optional loss metric. Calculated only if batch_y is provided.
        """
        total_loops = 0

        batch_start_boundaries = []
        batch_end_boundaries = []
        batch_align_matrix = []

        batch_size = len(batch_x_lens)

        # Calculate batch loss if batch_y is provided
        if batch_y is not None:
            loss_function = nn.NLLLoss()
            batch_loss = 0
        else:
            batch_loss = None

        for i in range(batch_size):
            cur_len = batch_x_lens[i]
            cur_encoder_hn = h_n[i, 0:cur_len, :]  # length * encoder_hidden_size
            cur_end_boundary = cur_len - 1  # end boundary is the index of the last token

            cur_y_index = batch_y[i] if batch_y is not None else None

            cur_end_boundaries = []
            cur_start_boundaries = []
            cur_align_matrix = []

            if self.rnn_type == "LSTM":  # need h_end and c_end
                # Reshape under fresh names so that h_end, which is reused on
                # every loop iteration, is not clobbered.
                h_end_reshaped = (
                    h_end[0]
                    .permute(1, 0, 2)
                    .contiguous()
                    .view(batch_size, self.num_rnn_layers, -1)
                )
                c_end_reshaped = (
                    h_end[1]
                    .permute(1, 0, 2)
                    .contiguous()
                    .view(batch_size, self.num_rnn_layers, -1)
                )
                cur_h0 = h_end_reshaped[i].unsqueeze(0).permute(1, 0, 2)
                cur_c0 = c_end_reshaped[i].unsqueeze(0).permute(1, 0, 2)
                h_pass = (cur_h0, cur_c0)
            else:  # only need h_end
                h_end_reshaped = (
                    h_end.permute(1, 0, 2)
                    .contiguous()
                    .view(batch_size, self.num_rnn_layers, -1)
                )
                cur_h0 = h_end_reshaped[i].unsqueeze(0).permute(1, 0, 2)
                h_pass = cur_h0

            loop_hc = h_pass
            loop_in = (
                cur_encoder_hn[0, :].unsqueeze(0).unsqueeze(0)
            )  # [1, 1, encoder_hidden_size] (first start)

            cur_start_boundary = 0
            loop_j = 0

            while True:
                loop_o, loop_hc = self.decoder_rnn(loop_in, loop_hc)

                predict_range = list(range(cur_start_boundary, cur_len))
                cur_encoder_hn_back = cur_encoder_hn[predict_range, :]
                cur_logits, cur_weights = self.pointer_layer(
                    cur_encoder_hn_back, loop_o.squeeze(0).squeeze(0)
                )

                cur_align_vector = np.zeros(cur_len)
                cur_align_vector[predict_range] = cur_weights.data.cpu().numpy()[0]
                cur_align_matrix.append(cur_align_vector)

                _, top_i = cur_logits.data.topk(1)
                pred_index = top_i[0][0].item()
                ori_pred_index = pred_index + cur_start_boundary

                # Calculate loss
                if batch_y is not None:
                    if loop_j > len(cur_y_index) - 1:
                        cur_ground_y = cur_y_index[-1]
                    else:
                        cur_ground_y = cur_y_index[loop_j]

                    cur_ground_y_var = Variable(
                        torch.LongTensor(
                            [max(0, int(cur_ground_y) - cur_start_boundary)]
                        )
                    ).to(self.device)
                    batch_loss += loss_function(cur_logits, cur_ground_y_var)

                if cur_end_boundary <= ori_pred_index:
                    cur_end_boundaries.append(cur_end_boundary)
                    cur_start_boundaries.append(cur_start_boundary)
                    total_loops = total_loops + 1
                    break
                else:
                    cur_end_boundaries.append(ori_pred_index)
                    loop_in = (
                        cur_encoder_hn[ori_pred_index + 1, :].unsqueeze(0).unsqueeze(0)
                    )
                    cur_start_boundaries.append(cur_start_boundary)
                    cur_start_boundary = ori_pred_index + 1  # start = pred_end + 1

                loop_j = loop_j + 1
                total_loops = total_loops + 1

            # For each instance in batch
            batch_end_boundaries.append(cur_end_boundaries)
            batch_start_boundaries.append(cur_start_boundaries)
            batch_align_matrix.append(cur_align_matrix)

        batch_loss = batch_loss / total_loops if batch_y is not None else None

        return (
            batch_start_boundaries,
            batch_end_boundaries,
            batch_align_matrix,
            batch_loss,
        )

    def forward(self, tokenized_sentence_ids, sentence_lens, labels=None):
        """
        Args:
            tokenized_sentence_ids: Token IDs.
            sentence_lens: Sentence lengths.
            labels: Optional. Provide if loss is needed.

        Returns:
            output (:class:`~sgnlp.models.rst_pointer.modeling.RstPointerSegmenterModelOutput`)
        """
        encoder_h_n, encoder_h_end = self.pointer_encoder(
            tokenized_sentence_ids, sentence_lens
        )
        start_boundaries, end_boundaries, _, loss = self.decoder(
            encoder_h_n, encoder_h_end, sentence_lens, labels
        )
        return RstPointerSegmenterModelOutput(loss, start_boundaries, end_boundaries)
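
# A minimal inference sketch (illustrative, not part of the original module).
# It assumes `token_ids` is the ELMo character-ID tensor produced by the
# accompanying preprocessor and `lengths` holds the per-sentence token counts;
# the helper name itself is hypothetical.
def _example_segmenter_inference(segmenter, token_ids, lengths):
    segmenter.eval()
    with torch.no_grad():
        output = segmenter(token_ids, lengths)
    # Each (start, end) pair delimits one predicted EDU, as inclusive token indices.
    return [
        list(zip(starts, ends))
        for starts, ends in zip(output.start_boundaries, output.end_boundaries)
    ]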

@dataclass
class RstPointerParserModelOutput(ModelOutput):
    loss_tree: np.array = None
    loss_label: np.array = None
    splits: List[List[DiscourseTreeSplit]] = None

class RstPointerParserPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """

    config_class = RstPointerParserConfig
    base_model_prefix = "rst_pointer_parser"

    def _init_weights(self, module):
        pass

class RstPointerParserModel(RstPointerParserPreTrainedModel):
    """This model performs discourse parsing.

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters
    related to general usage and behavior.

    Args:
        config (:class:`~sgnlp.models.rst_pointer.RstPointerParserConfig`):
            Model configuration class with all the parameters of the model. Initializing with a config file
            does not load the weights associated with the model, only the configuration.
            Use the :obj:`.from_pretrained` method to load the model weights.

    Example::

        from sgnlp.models.rst_pointer import RstPointerParserConfig, RstPointerParserModel

        # Method 1: Loading a default model
        parser_config = RstPointerParserConfig()
        parser = RstPointerParserModel(parser_config)

        # Method 2: Loading from pretrained
        parser_config = RstPointerParserConfig.from_pretrained(
            'https://storage.googleapis.com/sgnlp/models/rst_pointer/parser/config.json')
        parser = RstPointerParserModel.from_pretrained(
            'https://storage.googleapis.com/sgnlp/models/rst_pointer/parser/pytorch_model.bin',
            config=parser_config)
    """

    def __init__(self, config: RstPointerParserConfig):
        super().__init__(config)
        self.word_dim = config.word_dim
        self.hidden_size = config.hidden_size
        self.decoder_input_size = config.decoder_input_size
        self.atten_model = config.atten_model
        self.classifier_input_size = config.classifier_input_size
        self.classifier_hidden_size = config.classifier_hidden_size
        self.highorder = config.highorder
        self.classes_label = config.classes_label
        self.classifier_bias = config.classifier_bias
        self.rnn_layers = config.rnn_layers

        self.embedding, self.word_dim = initialize_elmo(config.elmo_size)
        self.encoder = EncoderRNN(
            word_dim=self.word_dim,
            hidden_size=self.hidden_size,
            rnn_layers=self.rnn_layers,
            dropout=config.dropout_e,
        )
        self.decoder = DecoderRNN(
            input_size=self.decoder_input_size,
            hidden_size=self.hidden_size,
            rnn_layers=self.rnn_layers,
            dropout=config.dropout_d,
        )
        self.pointer = PointerAtten(
            atten_model=self.atten_model, hidden_size=self.hidden_size
        )
        self.classifier = LabelClassifier(
            input_size=self.classifier_input_size,
            classifier_hidden_size=self.classifier_hidden_size,
            classes_label=self.classes_label,
            bias=self.classifier_bias,
            dropout=config.dropout_c,
        )

    def forward(
        self,
        input_sentence_ids,
        edu_breaks,
        sentence_lengths,
        label_index=None,
        parsing_index=None,
        generate_splits=True,
    ):
        """
        Args:
            input_sentence_ids: Input sentence IDs.
            edu_breaks: Token positions of EDU breaks.
            sentence_lengths: Lengths of sentences.
            label_index: Label IDs. Needed only if loss needs to be computed.
            parsing_index: Parsing IDs. Needed only if loss needs to be computed.
            generate_splits: Whether to return splits.

        Returns:
            output (:class:`~sgnlp.models.rst_pointer.modeling.RstPointerParserModelOutput`)
        """
        # Obtain encoder outputs and last hidden states
        embeddings = self.embedding(input_sentence_ids)
        encoder_outputs, last_hidden_states = self.encoder(embeddings, sentence_lengths)

        loss_function = nn.NLLLoss()
        loss_label_batch = 0
        loss_tree_batch = 0
        loop_label_batch = 0
        loop_tree_batch = 0
        label_batch = []
        tree_batch = []

        if generate_splits:
            splits_batch = []

        calculate_loss = bool(label_index and parsing_index)

        for i in range(len(edu_breaks)):
            # Reset per sentence so that predictions do not leak across batch items
            cur_label = []
            cur_tree = []

            if calculate_loss:
                cur_label_index = label_index[i]
                cur_label_index = torch.tensor(cur_label_index)
                cur_label_index = cur_label_index.to(self.device)
                cur_parsing_index = parsing_index[i]

            if len(edu_breaks[i]) == 1:
                # A sentence containing only ONE EDU has no corresponding
                # relation label or parsing tree break.
                tree_batch.append([])
                label_batch.append([])
                if generate_splits:
                    splits_batch.append([])
            elif len(edu_breaks[i]) == 2:
                # Take the last hidden state of an EDU as the representation of
                # this EDU. The dimension: [2, hidden_size]
                cur_encoder_outputs = encoder_outputs[i][edu_breaks[i]]

                # Directly run the classifier to obtain the predicted label
                input_left = cur_encoder_outputs[0].unsqueeze(0)
                input_right = cur_encoder_outputs[1].unsqueeze(0)
                relation_weights, log_relation_weights = self.classifier(
                    input_left, input_right
                )
                _, topindex = relation_weights.topk(1)
                label_predict = int(topindex[0][0])
                tree_batch.append([0])
                label_batch.append([label_predict])

                if calculate_loss:
                    loss_label_batch = loss_label_batch + loss_function(
                        log_relation_weights, cur_label_index
                    )
                    loop_label_batch = loop_label_batch + 1

                if generate_splits:
                    (
                        nuclearity_left,
                        nuclearity_right,
                        relation_left,
                        relation_right,
                    ) = get_relation_and_nucleus(label_predict)

                    split = DiscourseTreeSplit(
                        left=DiscourseTreeNode(
                            span=(0, 0), ns_type=nuclearity_left, label=relation_left
                        ),
                        right=DiscourseTreeNode(
                            span=(1, 1), ns_type=nuclearity_right, label=relation_right
                        ),
                    )
                    splits_batch.append([split])
            else:
                # Take the last hidden state of an EDU as the representation of this EDU.
                # The dimension: [NO_EDU, hidden_size]
                cur_encoder_outputs = encoder_outputs[i][edu_breaks[i]]

                edu_index = [x for x in range(len(cur_encoder_outputs))]
                stacks = ["__StackRoot__", edu_index]

                # cur_decoder_input: [1, 1, hidden_size]
                # An alternative is to take the last EDU as the input; the training
                # data would need to be prepared accordingly.
                cur_decoder_input = cur_encoder_outputs[0].unsqueeze(0).unsqueeze(0)

                # Obtain last hidden state
                temptest = torch.transpose(last_hidden_states, 0, 1)[i].unsqueeze(0)
                cur_last_hidden_states = torch.transpose(temptest, 0, 1)
                cur_last_hidden_states = cur_last_hidden_states.contiguous()

                cur_decoder_hidden = cur_last_hidden_states
                loop_index = 0

                if generate_splits:
                    splits = []
                if self.highorder:
                    cur_sibling = {}

                while stacks[-1] != "__StackRoot__":
                    stack_head = stacks[-1]

                    if len(stack_head) < 3:
                        # Predict relation label
                        input_left = cur_encoder_outputs[stack_head[0]].unsqueeze(0)
                        input_right = cur_encoder_outputs[stack_head[-1]].unsqueeze(0)
                        relation_weights, log_relation_weights = self.classifier(
                            input_left, input_right
                        )
                        _, topindex = relation_weights.topk(1)
                        label_predict = int(topindex[0][0])
                        cur_label.append(label_predict)

                        # For the 2-EDU case, we directly point to the first EDU
                        # as the current parsing tree break
                        cur_tree.append(stack_head[0])

                        # To keep decoder hidden states consistent
                        _, cur_decoder_hidden = self.decoder(
                            cur_decoder_input, cur_decoder_hidden
                        )

                        # Align ground-truth label
                        if calculate_loss:
                            if loop_index > (len(cur_parsing_index) - 1):
                                cur_label_true = cur_label_index[-1]
                            else:
                                cur_label_true = cur_label_index[loop_index]

                            loss_label_batch = loss_label_batch + loss_function(
                                log_relation_weights, cur_label_true.unsqueeze(0)
                            )
                            loop_label_batch = loop_label_batch + 1

                        loop_index = loop_index + 1
                        del stacks[-1]

                        if generate_splits:
                            # Generate the tree structure
                            (
                                nuclearity_left,
                                nuclearity_right,
                                relation_left,
                                relation_right,
                            ) = get_relation_and_nucleus(label_predict)

                            cur_split = DiscourseTreeSplit(
                                left=DiscourseTreeNode(
                                    span=(stack_head[0], stack_head[0]),
                                    ns_type=nuclearity_left,
                                    label=relation_left,
                                ),
                                right=DiscourseTreeNode(
                                    span=(stack_head[-1], stack_head[-1]),
                                    ns_type=nuclearity_right,
                                    label=relation_right,
                                ),
                            )
                            splits.append(cur_split)
                    else:  # Length of stack_head >= 3
                        # An alternative is to take the last EDU as the input; the
                        # training data would need to be prepared accordingly.
                        cur_decoder_input = (
                            cur_encoder_outputs[stack_head[0]].unsqueeze(0).unsqueeze(0)
                        )

                        if self.highorder:
                            if loop_index != 0:
                                # Incorporate parent information
                                cur_decoder_input_P = cur_encoder_outputs[stack_head[-1]]

                                # Incorporate sibling information
                                if str(stack_head) in cur_sibling.keys():
                                    cur_decoder_input_S = cur_encoder_outputs[
                                        cur_sibling[str(stack_head)]
                                    ]
                                    inputs_all = torch.cat(
                                        (
                                            cur_decoder_input.squeeze(0),
                                            cur_decoder_input_S.unsqueeze(0),
                                            cur_decoder_input_P.unsqueeze(0),
                                        ),
                                        0,
                                    )
                                    new_inputs_all = torch.matmul(
                                        F.softmax(
                                            torch.matmul(
                                                inputs_all, inputs_all.transpose(0, 1)
                                            ),
                                            0,
                                        ),
                                        inputs_all,
                                    )
                                    cur_decoder_input = (
                                        new_inputs_all[0, :]
                                        + new_inputs_all[1, :]
                                        + new_inputs_all[2, :]
                                    )
                                    cur_decoder_input = cur_decoder_input.unsqueeze(
                                        0
                                    ).unsqueeze(0)
                                    # cur_decoder_input = cur_decoder_input + cur_decoder_input_P + cur_decoder_input_S
                                else:
                                    inputs_all = torch.cat(
                                        (
                                            cur_decoder_input.squeeze(0),
                                            cur_decoder_input_P.unsqueeze(0),
                                        ),
                                        0,
                                    )
                                    new_inputs_all = torch.matmul(
                                        F.softmax(
                                            torch.matmul(
                                                inputs_all, inputs_all.transpose(0, 1)
                                            ),
                                            0,
                                        ),
                                        inputs_all,
                                    )
                                    cur_decoder_input = (
                                        new_inputs_all[0, :] + new_inputs_all[1, :]
                                    )
                                    cur_decoder_input = cur_decoder_input.unsqueeze(
                                        0
                                    ).unsqueeze(0)

                        # Predict the parsing tree break
                        cur_decoder_output, cur_decoder_hidden = self.decoder(
                            cur_decoder_input, cur_decoder_hidden
                        )
                        atten_weights, log_atten_weights = self.pointer(
                            cur_encoder_outputs[stack_head[:-1]],
                            cur_decoder_output.squeeze(0).squeeze(0),
                        )
                        _, topindex_tree = atten_weights.topk(1)
                        tree_predict = int(topindex_tree[0][0]) + stack_head[0]
                        cur_tree.append(tree_predict)

                        # Predict the label
                        input_left = cur_encoder_outputs[tree_predict].unsqueeze(0)
                        input_right = cur_encoder_outputs[stack_head[-1]].unsqueeze(0)
                        relation_weights, log_relation_weights = self.classifier(
                            input_left, input_right
                        )
                        _, topindex_label = relation_weights.topk(1)
                        label_predict = int(topindex_label[0][0])
                        cur_label.append(label_predict)

                        # Align ground-truth label and tree
                        if calculate_loss:
                            if loop_index > (len(cur_parsing_index) - 1):
                                cur_label_true = cur_label_index[-1]
                                cur_tree_true = cur_parsing_index[-1]
                            else:
                                cur_label_true = cur_label_index[loop_index]
                                cur_tree_true = cur_parsing_index[loop_index]

                            temp_ground = max(0, (int(cur_tree_true) - int(stack_head[0])))
                            if temp_ground >= (len(stack_head) - 1):
                                temp_ground = stack_head[-2] - stack_head[0]

                            # Compute tree loss
                            cur_ground_index = torch.tensor([temp_ground])
                            cur_ground_index = cur_ground_index.to(self.device)
                            loss_tree_batch = loss_tree_batch + loss_function(
                                log_atten_weights, cur_ground_index
                            )

                            # Compute classifier loss
                            loss_label_batch = loss_label_batch + loss_function(
                                log_relation_weights, cur_label_true.unsqueeze(0)
                            )
                            loop_label_batch = loop_label_batch + 1
                            loop_tree_batch = loop_tree_batch + 1

                        # Stacks stuff
                        stack_down = stack_head[(tree_predict - stack_head[0] + 1):]
                        stack_top = stack_head[: (tree_predict - stack_head[0] + 1)]
                        del stacks[-1]
                        loop_index = loop_index + 1

                        # Sibling information
                        if self.highorder:
                            if len(stack_down) > 2:
                                cur_sibling.update({str(stack_down): stack_top[-1]})

                        # Remove the ONE-EDU part
                        if len(stack_down) > 1:
                            stacks.append(stack_down)
                        if len(stack_top) > 1:
                            stacks.append(stack_top)

                        if generate_splits:
                            (
                                nuclearity_left,
                                nuclearity_right,
                                relation_left,
                                relation_right,
                            ) = get_relation_and_nucleus(label_predict)

                            cur_split = DiscourseTreeSplit(
                                left=DiscourseTreeNode(
                                    span=(stack_head[0], tree_predict),
                                    ns_type=nuclearity_left,
                                    label=relation_left,
                                ),
                                right=DiscourseTreeNode(
                                    span=(tree_predict + 1, stack_head[-1]),
                                    ns_type=nuclearity_right,
                                    label=relation_right,
                                ),
                            )
                            splits.append(cur_split)

                tree_batch.append(cur_tree)
                label_batch.append(cur_label)
                if generate_splits:
                    splits_batch.append(splits)

        if calculate_loss:
            if loop_label_batch != 0:
                loss_label_batch = loss_label_batch / loop_label_batch
                loss_label_batch = loss_label_batch.detach().cpu().numpy()

            if loss_tree_batch != 0:
                loss_tree_batch = loss_tree_batch / loop_tree_batch
                loss_tree_batch = loss_tree_batch.detach().cpu().numpy()
        else:
            loss_tree_batch = None
            loss_label_batch = None

        return RstPointerParserModelOutput(
            loss_tree_batch,
            loss_label_batch,
            (splits_batch if generate_splits else None),
        )
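
    # Note: unlike forward(), which decodes greedily with a span stack,
    # forward_train() below runs the decoder over precomputed ground-truth
    # inputs (teacher forcing). decoder_input_index, parents_index and
    # sibling_index supply the per-step EDU indices prepared by the training
    # pipeline; a sibling index of 99 appears to act as a "no sibling" sentinel.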
    def forward_train(
        self,
        input_sentence_ids_batch,
        edu_breaks_batch,
        label_index_batch,
        parsing_index_batch,
        decoder_input_index_batch,
        parents_index_batch,
        sibling_index_batch,
        sentence_lengths,
    ):
        # TODO: This function should ideally be combined with the forward function.
        # There is significant overlap in code, but also some significant differences
        # in the logic which make refactoring them difficult.
        # To retain the original code's fidelity, this function is used for the
        # training forward pass.

        # Obtain encoder outputs and last hidden states
        embeddings = self.embedding(input_sentence_ids_batch)
        encoder_outputs, last_hiddenstates = self.encoder(embeddings, sentence_lengths)

        loss_function = nn.NLLLoss()
        loss_label_batch = 0
        loss_tree_batch = 0
        loop_label_batch = 0
        loop_tree_batch = 0

        for i in range(input_sentence_ids_batch.shape[0]):
            cur_label_index = label_index_batch[i]
            cur_label_index = torch.tensor(cur_label_index)
            cur_label_index = cur_label_index.to(self.device)
            cur_parsing_index = parsing_index_batch[i]
            cur_decoder_input_index = decoder_input_index_batch[i]
            cur_parents_index = parents_index_batch[i]
            cur_sibling_index = sibling_index_batch[i]

            if len(edu_breaks_batch[i]) == 1:
                continue
            elif len(edu_breaks_batch[i]) == 2:
                # Take the last hidden state of an EDU as the representation of
                # this EDU. The dimension: [2, hidden_size]
                cur_encoder_outputs = encoder_outputs[i][edu_breaks_batch[i]]

                # Use the last hidden state of a span to predict the relation
                # between these two spans.
                input_left = cur_encoder_outputs[0].unsqueeze(0)
                input_right = cur_encoder_outputs[1].unsqueeze(0)
                _, log_relation_weights = self.classifier(input_left, input_right)
                loss_label_batch = loss_label_batch + loss_function(
                    log_relation_weights, cur_label_index
                )
                loop_label_batch = loop_label_batch + 1
            else:
                # Take the last hidden state of an EDU as the representation of this EDU.
                # The dimension: [NO_EDU, hidden_size]
                cur_encoder_outputs = encoder_outputs[i][edu_breaks_batch[i]].to(
                    self.device
                )

                # Obtain the last hidden state of the encoder
                temp = torch.transpose(last_hiddenstates, 0, 1)[i].unsqueeze(0)
                cur_last_hiddenstates = torch.transpose(temp, 0, 1)
                cur_last_hiddenstates = cur_last_hiddenstates.contiguous()

                if self.highorder:
                    # Incorporate parent information
                    cur_decoder_inputs_P = cur_encoder_outputs[cur_parents_index]
                    cur_decoder_inputs_P[0] = 0

                    # Incorporate sibling information
                    cur_decoder_inputs_S = torch.zeros(
                        [len(cur_sibling_index), cur_encoder_outputs.shape[1]]
                    ).to(self.device)
                    for n, s_idx in enumerate(cur_sibling_index):
                        if s_idx != 99:
                            cur_decoder_inputs_S[n] = cur_encoder_outputs[s_idx]

                    # Original input
                    cur_decoder_inputs = cur_encoder_outputs[cur_decoder_input_index]

                    # One-layer self attention
                    inputs_all = torch.cat(
                        (
                            cur_decoder_inputs.unsqueeze(0).transpose(0, 1),
                            cur_decoder_inputs_S.unsqueeze(0).transpose(0, 1),
                            cur_decoder_inputs_P.unsqueeze(0).transpose(0, 1),
                        ),
                        1,
                    )
                    new_inputs_all = torch.matmul(
                        F.softmax(
                            torch.matmul(inputs_all, inputs_all.transpose(1, 2)), 1
                        ),
                        inputs_all,
                    )
                    cur_decoder_inputs = (
                        new_inputs_all[:, 0, :]
                        + new_inputs_all[:, 1, :]
                        + new_inputs_all[:, 2, :]
                    )
                else:
                    cur_decoder_inputs = cur_encoder_outputs[cur_decoder_input_index]

                # Obtain decoder outputs
                cur_decoder_outputs, _ = self.decoder(
                    cur_decoder_inputs.unsqueeze(0), cur_last_hiddenstates
                )
                cur_decoder_outputs = cur_decoder_outputs.squeeze(0)

                edu_index = [x for x in range(len(cur_encoder_outputs))]
                stacks = ["__StackRoot__", edu_index]

                for j in range(len(cur_decoder_outputs)):
                    if stacks[-1] != "__StackRoot__":
                        stack_head = stacks[-1]

                        if len(stack_head) < 3:
                            # We remove this span from the stacks after computing
                            # the relation between these two EDUs.

                            # Compute classifier loss
                            input_left = cur_encoder_outputs[
                                cur_parsing_index[j]
                            ].unsqueeze(0)
                            input_right = cur_encoder_outputs[stack_head[-1]].unsqueeze(0)
                            _, log_relation_weights = self.classifier(
                                input_left, input_right
                            )
                            loss_label_batch = loss_label_batch + loss_function(
                                log_relation_weights, cur_label_index[j].unsqueeze(0)
                            )
                            del stacks[-1]
                            loop_label_batch = loop_label_batch + 1
                        else:  # Length of stack_head >= 3
                            # Compute tree loss
                            # We don't attend to the last EDU of the span to be parsed
                            _, log_atten_weights = self.pointer(
                                cur_encoder_outputs[stack_head[:-1]],
                                cur_decoder_outputs[j],
                            )
                            cur_ground_index = torch.tensor(
                                [int(cur_parsing_index[j]) - int(stack_head[0])]
                            )
                            cur_ground_index = cur_ground_index.to(self.device)
                            loss_tree_batch = loss_tree_batch + loss_function(
                                log_atten_weights, cur_ground_index
                            )

                            # Compute classifier loss
                            input_left = cur_encoder_outputs[
                                cur_parsing_index[j]
                            ].unsqueeze(0)
                            input_right = cur_encoder_outputs[stack_head[-1]].unsqueeze(0)
                            _, log_relation_weights = self.classifier(
                                input_left, input_right
                            )
                            loss_label_batch = loss_label_batch + loss_function(
                                log_relation_weights, cur_label_index[j].unsqueeze(0)
                            )

                            # Stacks stuff
                            stack_down = stack_head[
                                (cur_parsing_index[j] - stack_head[0] + 1):
                            ]
                            stack_top = stack_head[
                                : (cur_parsing_index[j] - stack_head[0] + 1)
                            ]
                            del stacks[-1]
                            loop_label_batch = loop_label_batch + 1
                            loop_tree_batch = loop_tree_batch + 1

                            # Remove the ONE-EDU part; a TWO-EDU span is removed
                            # after the classifier in the next step
                            if len(stack_down) > 1:
                                stacks.append(stack_down)
                            if len(stack_top) > 1:
                                stacks.append(stack_top)

        if loop_label_batch != 0:
            loss_label_batch = loss_label_batch / loop_label_batch
        if loss_tree_batch != 0:
            loss_tree_batch = loss_tree_batch / loop_tree_batch

        return loss_tree_batch, loss_label_batch
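
# A minimal training-step sketch (illustrative, not part of the original
# module). The batch field names and the optimizer are assumptions;
# forward_train returns the tree and label losses separately, and summing
# them is one natural way to combine them for backpropagation.
def _example_parser_train_step(parser, optimizer, batch):
    parser.train()
    loss_tree, loss_label = parser.forward_train(
        batch["input_sentence_ids"],
        batch["edu_breaks"],
        batch["label_index"],
        batch["parsing_index"],
        batch["decoder_input_index"],
        batch["parents_index"],
        batch["sibling_index"],
        batch["sentence_lengths"],
    )
    loss = loss_tree + loss_label  # assumed 1:1 weighting of the two objectives
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss)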