Source code for sgnlp.models.lsr.modeling

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel, BertModel

from .config import LsrConfig
from .modules.attention import SelfAttention
from .modules.encoder import Encoder
from .modules.reasoner import DynamicReasoner
from .modules.reasoner import StructInduction


@dataclass
class LsrModelOutput:
    """
    Output type of :class:`~sgnlp.models.lsr.modeling.LsrModel`

    Args:
        prediction (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, max_h_t_count, num_relations)`):
            Prediction scores for all head to tail entity combinations from the final layer.
            Note that the sigmoid function has not been applied at this point.
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`relation_multi_label` is provided):
            Loss on the relation prediction task.
    """

    prediction: torch.FloatTensor
    loss: Optional[torch.FloatTensor] = None
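
# Illustrative sketch (not part of the original module): `LsrModelOutput.prediction`
# holds raw logits, so downstream code typically applies a sigmoid before thresholding;
# `loss` is populated only when `relation_multi_label` is passed to `LsrModel.forward`.
def _example_relation_probabilities(output: LsrModelOutput) -> torch.FloatTensor:
    """Hypothetical helper: convert LsrModelOutput logits into per-relation probabilities."""
    # Shape: (batch_size, max_h_t_count, num_relations)
    return torch.sigmoid(output.prediction)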
class LsrPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LsrConfig
    base_model_prefix = "lsr"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
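
# Note (explanatory comment, not part of the original module): LsrModel calls
# `self.init_weights()` at the end of its constructor; through the transformers
# `PreTrainedModel` machinery this applies `_init_weights` above to each submodule,
# so layers that are not loaded from a checkpoint start from the normal / zeros /
# ones initialization defined here.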
class LsrModel(LsrPreTrainedModel):
    """The Latent Structure Refinement Model performs relation classification on all pairs of entity clusters.

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.

    Args:
        config (:class:`~sgnlp.models.lsr.config.LsrConfig`):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration.
            Use the :obj:`.from_pretrained` method to load the model weights.

    Example::

        from sgnlp.models.lsr import LsrModel, LsrConfig

        # Method 1: Loading a default model
        config = LsrConfig()
        model = LsrModel(config)

        # Method 2: Loading from pretrained
        config = LsrConfig.from_pretrained('https://storage.googleapis.com/sgnlp/models/lsr/config.json')
        model = LsrModel.from_pretrained('https://storage.googleapis.com/sgnlp/models/lsr/pytorch_model.bin',
                                         config=config)
    """

    def __init__(self, config: LsrConfig):
        super().__init__(config)
        self.config = config

        # Common
        self.dropout = nn.Dropout(config.dropout_rate)
        self.relu = nn.ReLU()

        # Document encoder layers
        if config.use_bert:
            self.bert = BertModel.from_pretrained("bert-base-uncased")
            bert_hidden_size = 768
            self.linear_re = nn.Linear(bert_hidden_size, config.hidden_dim)
        else:
            self.word_emb = nn.Embedding(
                config.word_embedding_shape[0], config.word_embedding_shape[1]
            )
            if not config.finetune_emb:
                self.word_emb.weight.requires_grad = False
            self.ner_emb = nn.Embedding(13, config.ner_dim, padding_idx=0)
            self.coref_embed = nn.Embedding(
                config.max_length, config.coref_dim, padding_idx=0
            )
            self.linear_re = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
            input_size = (
                config.word_embedding_shape[1] + config.coref_dim + config.ner_dim
            )
            self.rnn_sent = Encoder(
                input_size, config.hidden_dim, config.dropout_emb, config.dropout_rate
            )

        # Induce latent structure layers
        self.use_struct_att = config.use_struct_att
        if self.use_struct_att:
            self.struct_induction = StructInduction(
                config.hidden_dim // 2, config.hidden_dim, True
            )
        self.dropout_gcn = nn.Dropout(config.dropout_gcn)
        self.use_reasoning_block = config.use_reasoning_block
        if self.use_reasoning_block:
            self.reasoner = nn.ModuleList()
            self.reasoner.append(
                DynamicReasoner(
                    config.hidden_dim, config.reasoner_layer_sizes[0], self.dropout_gcn
                )
            )
            self.reasoner.append(
                DynamicReasoner(
                    config.hidden_dim, config.reasoner_layer_sizes[1], self.dropout_gcn
                )
            )

        # Output layers
        self.dis_embed = nn.Embedding(20, config.distance_size, padding_idx=10)
        self.self_att = SelfAttention(config.hidden_dim)
        self.bili = torch.nn.Bilinear(
            config.hidden_dim + config.distance_size,
            config.hidden_dim + config.distance_size,
            config.hidden_dim,
        )
        self.linear_output = nn.Linear(2 * config.hidden_dim, config.num_relations)

        self.init_weights()

    def load_pretrained_word_embedding(self, pretrained_word_embedding):
        self.word_emb.weight.data.copy_(torch.from_numpy(pretrained_word_embedding))

    def doc_encoder(self, input_sent, context_seg):
        batch_size = context_seg.shape[0]
        docs_emb = []  # per-document embeddings built from sentence encoder outputs
        docs_len = []
        sents_emb = []

        for batch_no in range(batch_size):
            sent_list = []
            sent_lens = []
            sent_index = (
                ((context_seg[batch_no] == 1).nonzero()).squeeze(-1).tolist()
            )  # start positions of sentences in the document
            pre_index = 0
            for i, index in enumerate(sent_index):
                if i != 0:
                    if i == 1:
                        sent_list.append(input_sent[batch_no][pre_index : index + 1])
                        sent_lens.append(index - pre_index + 1)
                    else:
                        sent_list.append(
                            input_sent[batch_no][pre_index + 1 : index + 1]
                        )
                        sent_lens.append(index - pre_index)
                pre_index = index

            sents = pad_sequence(sent_list).permute(1, 0, 2)
            sent_lens_t = torch.LongTensor(sent_lens).to(device=self.device)
            docs_len.append(sent_lens)
            sents_output, sent_emb = self.rnn_sent(
                sents, sent_lens_t
            )  # sentence embeddings for a document

            doc_emb = None
            for i, (sen_len, emb) in enumerate(zip(sent_lens, sents_output)):
                if i == 0:
                    doc_emb = emb[:sen_len]
                else:
                    doc_emb = torch.cat([doc_emb, emb[:sen_len]], dim=0)
            docs_emb.append(doc_emb)
            sents_emb.append(sent_emb.squeeze(1))

        docs_emb = pad_sequence(docs_emb).permute(1, 0, 2)  # batch_size * doc_len * dim
        sents_emb = pad_sequence(sents_emb).permute(1, 0, 2)

        return docs_emb, sents_emb
    def forward(
        self,
        context_idxs,
        context_pos,
        context_ner,
        h_mapping,
        t_mapping,
        relation_mask,
        dis_h_2_t,
        dis_t_2_h,
        context_seg,
        node_position,
        entity_position,
        node_sent_num,
        all_node_num,
        entity_num_list,
        sdp_position,
        sdp_num_list,
        context_masks=None,
        context_starts=None,
        relation_multi_label=None,
        **kwargs
    ):
        # TODO: kwargs are currently ignored, to allow preprocessing to pass in unnecessary arguments.
        # TODO: Fix upstream preprocessing so that unused arguments are filtered out before being passed in.
        """
        Args:
            context_idxs (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_tokens_length)`):
                Token IDs.
            context_pos (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_tokens_length)`):
                Coref position IDs.
            context_ner (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_tokens_length)`):
                NER tag IDs.
            h_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, h_t_limit, max_tokens_length)`):
                Head entity position mapping.
            t_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, h_t_limit, max_tokens_length)`):
                Tail entity position mapping.
            relation_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, h_t_limit)`):
                Relation mask. 1 if a relation exists at the position, else 0.
            dis_h_2_t (:obj:`torch.LongTensor` of shape :obj:`(batch_size, h_t_limit)`):
                Distance encoding from head to tail.
            dis_t_2_h (:obj:`torch.LongTensor` of shape :obj:`(batch_size, h_t_limit)`):
                Distance encoding from tail to head.
            context_seg (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_tokens_length)`):
                Start positions of sentences in the document. 1 marks the start of a sentence, else 0.
            node_position (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_node_number, max_tokens_length)`):
                Mention node positions.
            entity_position (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_entity_number, max_tokens_length)`):
                Entity node positions. An entity refers to all mentions referring to the same entity.
            node_sent_num (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_sent_num)`):
                Number of mention nodes in each sentence of a document.
            all_node_num (:obj:`torch.LongTensor` of shape :obj:`(1)`):
                Total number of nodes (mention + MDP) in a document.
            entity_num_list (:obj:`List[int]` of shape :obj:`(batch_size)`):
                Number of entity nodes in each document.
            sdp_position (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_entity_number, max_tokens_length)`):
                Meta dependency path (MDP) node positions.
            sdp_num_list (:obj:`List[int]` of shape :obj:`(batch_size)`):
                Number of MDP nodes in each document.
            context_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_length)`, `optional`):
                Mask for padding tokens. Used by the BERT encoder only.
            context_starts (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_length)`, `optional`):
                Tensor indicating the start of words. Used by the BERT encoder only.
            relation_multi_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size, h_t_limit, num_relations)`, `optional`):
                Labels for all possible head to tail entity relations. Required to compute the loss.

        Returns:
            output (:class:`~sgnlp.models.lsr.modeling.LsrModelOutput`)
        """
        # Step 1: Encode the document
        if self.config.use_bert:
            context_output = self.bert(context_idxs, attention_mask=context_masks)[0]
            context_output = [
                layer[starts.nonzero().squeeze(1)]
                for layer, starts in zip(context_output, context_starts)
            ]
            context_output = pad_sequence(
                context_output, batch_first=True, padding_value=-1
            )
            context_output = torch.nn.functional.pad(
                context_output,
                (0, 0, 0, context_idxs.size(-1) - context_output.size(-2)),
            )
            context_output = self.dropout(torch.relu(self.linear_re(context_output)))
            max_doc_len = 512
        else:
            sent_emb = torch.cat(
                [
                    self.word_emb(context_idxs),
                    self.coref_embed(context_pos),
                    self.ner_emb(context_ner),
                ],
                dim=-1,
            )
            docs_rep, sents_rep = self.doc_encoder(sent_emb, context_seg)
            max_doc_len = docs_rep.shape[1]
            context_output = self.dropout(torch.relu(self.linear_re(docs_rep)))

        # Step 2: Extract all node reps of a document graph
        # extract mention node representations
        mention_num_list = torch.sum(node_sent_num, dim=1).tolist()
        max_mention_num = max(mention_num_list)
        mentions_rep = torch.bmm(
            node_position[:, :max_mention_num, :max_doc_len], context_output
        )

        # extract meta dependency paths (MDP) node representations
        max_sdp_num = max(sdp_num_list)
        sdp_rep = torch.bmm(sdp_position[:, :max_sdp_num, :max_doc_len], context_output)

        # extract entity node representations
        entity_rep = torch.bmm(entity_position[:, :, :max_doc_len], context_output)

        # concatenate all nodes of an instance
        gcn_inputs = []
        all_node_num_batch = []
        for batch_no, (m_n, e_n, s_n) in enumerate(
            zip(mention_num_list, entity_num_list, sdp_num_list)
        ):
            m_rep = mentions_rep[batch_no][:m_n]
            e_rep = entity_rep[batch_no][:e_n]
            s_rep = sdp_rep[batch_no][:s_n]
            gcn_inputs.append(torch.cat((m_rep, e_rep, s_rep), dim=0))
            node_num = m_n + e_n + s_n
            all_node_num_batch.append(node_num)

        gcn_inputs = pad_sequence(gcn_inputs).permute(1, 0, 2)
        output = gcn_inputs

        # Step 3: Induce the Latent Structure
        if self.use_reasoning_block:
            for i in range(len(self.reasoner)):
                output = self.reasoner[i](output)
        elif self.use_struct_att:
            gcn_inputs, _ = self.struct_induction(gcn_inputs)

        max_all_node_num = torch.max(all_node_num).item()
        assert gcn_inputs.shape[1] == max_all_node_num

        node_position = node_position.permute(0, 2, 1)
        output = torch.bmm(
            node_position[:, :max_doc_len, :max_mention_num],
            output[:, :max_mention_num],
        )
        context_output = torch.add(context_output, output)

        start_re_output = torch.matmul(
            h_mapping[:, :, :max_doc_len], context_output
        )  # aggregation
        end_re_output = torch.matmul(
            t_mapping[:, :, :max_doc_len], context_output
        )  # aggregation

        s_rep = torch.cat([start_re_output, self.dis_embed(dis_h_2_t)], dim=-1)
        t_rep = torch.cat([end_re_output, self.dis_embed(dis_t_2_h)], dim=-1)

        re_rep = self.dropout(self.relu(self.bili(s_rep, t_rep)))
        re_rep = self.self_att(re_rep, re_rep, relation_mask)
        prediction = self.linear_output(re_rep)

        loss = None
        if relation_multi_label is not None:
            loss_fn = nn.BCEWithLogitsLoss(reduction="none")
            loss = torch.sum(
                loss_fn(prediction, relation_multi_label) * relation_mask.unsqueeze(2)
            ) / torch.sum(relation_mask)

        return LsrModelOutput(prediction=prediction, loss=loss)
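
# A minimal usage sketch (an illustrative addition, not part of the original module).
# Loading mirrors the example in the LsrModel docstring; `batch` is assumed to be a
# dict of tensors produced by the accompanying LSR preprocessing code, keyed by the
# forward() argument names documented above.
if __name__ == "__main__":
    config = LsrConfig.from_pretrained("https://storage.googleapis.com/sgnlp/models/lsr/config.json")
    model = LsrModel.from_pretrained(
        "https://storage.googleapis.com/sgnlp/models/lsr/pytorch_model.bin", config=config
    )
    model.eval()

    # Inference would then look like the following, assuming `batch` has been prepared:
    #
    #     with torch.no_grad():
    #         output = model(**batch)              # LsrModelOutput
    #     scores = torch.sigmoid(output.prediction)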