LsrModel

class LsrModel(config: sgnlp.models.lsr.config.LsrConfig)

The Latent Structure Refinement (LSR) model performs relation classification on all pairs of entity clusters in a document.
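As a rough illustration of what "all pairs" means, the sketch below enumerates candidate (head, tail) pairs of entity indices in plain Python. It assumes the pairs are ordered, which the separate head and tail inputs of forward suggest, and that candidates beyond a fixed budget are cut off at h_t_limit slots. The helper candidate_pairs is hypothetical and not part of sgnlp.

```python
from itertools import permutations
from typing import List, Tuple


def candidate_pairs(num_entities: int, h_t_limit: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: ordered (head, tail) entity-index pairs.

    Assumes every ordered pair of distinct entities is a candidate and
    that the list is truncated to h_t_limit slots.
    """
    # permutations(..., 2) yields (h, t) for all h != t, in lexicographic order.
    return list(permutations(range(num_entities), 2))[:h_t_limit]


print(candidate_pairs(3, h_t_limit=10))
# [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
```

Under this assumption, a document with n entity clusters yields n(n - 1) candidate pairs, so the number of classifications grows quadratically with the number of entities.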

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

config (LsrConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Use the from_pretrained method to load the model weights.

Example:

from sgnlp.models.lsr import LsrModel, LsrConfig

# Method 1: Build a model from the default configuration (weights are not loaded)
config = LsrConfig()
model = LsrModel(config)

# Method 2: Loading from pretrained
config = LsrConfig.from_pretrained('https://storage.googleapis.com/sgnlp/models/lsr/config.json')
model = LsrModel.from_pretrained('https://storage.googleapis.com/sgnlp/models/lsr/pytorch_model.bin',
                                 config=config)

forward(context_idxs, context_pos, context_ner, h_mapping, t_mapping, relation_mask, dis_h_2_t, dis_t_2_h, context_seg, node_position, entity_position, node_sent_num, all_node_num, entity_num_list, sdp_position, sdp_num_list, context_masks=None, context_starts=None, relation_multi_label=None, **kwargs)
Parameters
  • context_idxs (torch.LongTensor of shape (batch_size, max_tokens_length)) – Token IDs.

  • context_pos (torch.LongTensor of shape (batch_size, max_tokens_length)) – Coref position IDs.

  • context_ner (torch.LongTensor of shape (batch_size, max_tokens_length)) – NER tag IDs.

  • h_mapping (torch.FloatTensor of shape (batch_size, h_t_limit, max_tokens_length)) – Head entity position mapping.

  • t_mapping (torch.FloatTensor of shape (batch_size, h_t_limit, max_tokens_length)) – Tail entity position mapping.

  • relation_mask (torch.FloatTensor of shape (batch_size, h_t_limit)) – Relation mask: 1 if a relation exists at that position, else 0.

  • dis_h_2_t (torch.LongTensor of shape (batch_size, h_t_limit)) – Distance encoding from head to tail.

  • dis_t_2_h (torch.LongTensor of shape (batch_size, h_t_limit)) – Distance encoding from tail to head.

  • context_seg (torch.LongTensor of shape (batch_size, max_tokens_length)) – Sentence start markers for the document: 1 if the token at that position starts a sentence, else 0.

  • node_position (torch.LongTensor of shape (batch_size, max_node_number, max_tokens_length)) – Mention node position.

  • entity_position (torch.LongTensor of shape (batch_size, max_entity_number, max_tokens_length)) – Entity node position. An entity node groups all mentions that refer to the same entity.

  • node_sent_num (torch.LongTensor of shape (batch_size, max_sent_num)) – Number of mention nodes in each sentence of a document.

  • all_node_num (torch.LongTensor of shape (1)) – Total number of nodes (mention + MDP) in a document.

  • entity_num_list (List[int] of length batch_size) – Number of entity nodes in each document.

  • sdp_position (torch.LongTensor of shape (batch_size, max_entity_number, max_tokens_length)) – Meta dependency paths (MDP) node position.

  • sdp_num_list (List[int] of length batch_size) – Number of MDP nodes in each document.

  • context_masks (torch.LongTensor of shape (batch_size, max_length), optional) – Mask for padding tokens. Used only by the BERT-based model.

  • context_starts (torch.LongTensor of shape (batch_size, max_length), optional) – Marks the start of each word. Used only by the BERT-based model.

  • relation_multi_label (torch.LongTensor of shape (batch_size, h_t_limit, num_relations), optional) – Multi-label relation targets for every head-to-tail entity pair.
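The pair-level inputs h_mapping, t_mapping and relation_mask can be pictured with a toy, plain-Python sketch for a single document. This is not the sgnlp preprocessing code. The helper names are hypothetical, the per-token weighting (averaging over every token of every mention of an entity) is an assumption modelled on DocRED-style preprocessing, and a relation_mask slot is assumed to be 1 whenever it holds a candidate head-tail pair, with padding slots left at 0.

```python
from itertools import permutations
from typing import List, Tuple

Span = Tuple[int, int]  # (start, end) token offsets of a mention, end exclusive


def _fill_row(row: List[float], mentions: List[Span]) -> None:
    """Spread a total weight of 1.0 over all tokens of all mentions (assumption)."""
    for start, end in mentions:
        weight = 1.0 / len(mentions) / (end - start)
        for pos in range(start, end):
            row[pos] += weight


def build_pair_inputs(entities: List[List[Span]], max_tokens_length: int, h_t_limit: int):
    """Hypothetical helper: per-document h_mapping, t_mapping and relation_mask.

    entities[k] lists the mention spans of entity cluster k.
    Returns nested lists of shape (h_t_limit, max_tokens_length),
    (h_t_limit, max_tokens_length) and (h_t_limit,).
    """
    h_mapping = [[0.0] * max_tokens_length for _ in range(h_t_limit)]
    t_mapping = [[0.0] * max_tokens_length for _ in range(h_t_limit)]
    relation_mask = [0.0] * h_t_limit

    # Ordered (head, tail) pairs of distinct entities, truncated to h_t_limit slots.
    pairs = list(permutations(range(len(entities)), 2))[:h_t_limit]
    for slot, (h, t) in enumerate(pairs):
        _fill_row(h_mapping[slot], entities[h])
        _fill_row(t_mapping[slot], entities[t])
        relation_mask[slot] = 1.0  # slot holds a real candidate pair
    return h_mapping, t_mapping, relation_mask


# Entity 0 has mentions at tokens 0-1 and 7; entity 1 has one mention at token 4.
h, t, mask = build_pair_inputs([[(0, 2), (7, 8)], [(4, 5)]], max_tokens_length=10, h_t_limit=4)
print(mask)  # [1.0, 1.0, 0.0, 0.0]
print(h[0])  # [0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0]
```

To reach the documented shapes, each returned list could be converted with torch.tensor(..., dtype=torch.float) and the per-document tensors combined with torch.stack to add the batch_size dimension.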

Returns

output (LsrModelOutput)