LsrModel

class LsrModel(config: sgnlp.models.lsr.config.LsrConfig)

The Latent Structure Refinement (LSR) model performs relation classification on all pairs of entity clusters in a document.

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
- Parameters

  config (LsrConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Use the from_pretrained method to load the model weights.
Example:
from sgnlp.models.lsr import LsrModel, LsrConfig

# Method 1: Loading a default model
config = LsrConfig()
model = LsrModel(config)

# Method 2: Loading from pretrained
config = LsrConfig.from_pretrained('https://storage.googleapis.com/sgnlp/models/lsr/config.json')
model = LsrModel.from_pretrained('https://storage.googleapis.com/sgnlp/models/lsr/pytorch_model.bin', config=config)
forward(context_idxs, context_pos, context_ner, h_mapping, t_mapping, relation_mask, dis_h_2_t, dis_t_2_h, context_seg, node_position, entity_position, node_sent_num, all_node_num, entity_num_list, sdp_position, sdp_num_list, context_masks=None, context_starts=None, relation_multi_label=None, **kwargs)

- Parameters
  context_idxs (torch.LongTensor of shape (batch_size, max_tokens_length)) – Token IDs.

  context_pos (torch.LongTensor of shape (batch_size, max_tokens_length)) – Coref position IDs.

  context_ner (torch.LongTensor of shape (batch_size, max_tokens_length)) – NER tag IDs.

  h_mapping (torch.FloatTensor of shape (batch_size, h_t_limit, max_tokens_length)) – Head entity position mapping.

  t_mapping (torch.FloatTensor of shape (batch_size, h_t_limit, max_tokens_length)) – Tail entity position mapping.

  relation_mask (torch.FloatTensor of shape (batch_size, h_t_limit)) – Relation mask; 1 if a relation exists at the position, else 0.

  dis_h_2_t (torch.LongTensor of shape (batch_size, h_t_limit)) – Distance encoding from head to tail.

  dis_t_2_h (torch.LongTensor of shape (batch_size, h_t_limit)) – Distance encoding from tail to head.

  context_seg (torch.LongTensor of shape (batch_size, max_tokens_length)) – Start positions of sentences in the document; 1 marks the start of a sentence, else 0.

  node_position (torch.LongTensor of shape (batch_size, max_node_number, max_tokens_length)) – Mention node positions.

  entity_position (torch.LongTensor of shape (batch_size, max_entity_number, max_tokens_length)) – Entity node positions. An entity node covers all mentions that refer to the same entity.

  node_sent_num (torch.LongTensor of shape (batch_size, max_sent_num)) – Number of mention nodes in each sentence of a document.

  all_node_num (torch.LongTensor of shape (1)) – Total number of nodes (mention + MDP) in a document.

  entity_num_list (List[int] of length batch_size) – Number of entity nodes in each document.

  sdp_position (torch.LongTensor of shape (batch_size, max_entity_number, max_tokens_length)) – Meta dependency path (MDP) node positions.

  sdp_num_list (List[int] of length batch_size) – Number of MDP nodes in each document.

  context_masks (torch.LongTensor of shape (batch_size, max_length), optional) – Mask for padding tokens. Used by the BERT model only.

  context_starts (torch.LongTensor of shape (batch_size, max_length), optional) – Tensor indicating the start of words. Used by the BERT model only.

  relation_multi_label (torch.LongTensor of shape (batch_size, h_t_limit, num_relations), optional) – Labels for all possible head-to-tail entity relations.
- Returns

  output (LsrModelOutput)
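Several of the inputs above (h_mapping, t_mapping, relation_mask, dis_h_2_t, dis_t_2_h, relation_multi_label) are indexed by ordered head-tail entity pairs, padded up to h_t_limit. As a minimal sketch of how that pair dimension is typically laid out (this is not sgnlp API; the helper name and exact padding scheme are assumptions for illustration):

```python
from itertools import permutations

def build_pair_inputs(entity_num: int, h_t_limit: int):
    """Hypothetical helper: enumerate ordered (head, tail) entity-pair
    indices for one document and build the padded relation mask,
    mirroring the (batch_size, h_t_limit) layout described above."""
    # All ordered pairs of distinct entity clusters: n * (n - 1) of them.
    pairs = [(h, t) for h, t in permutations(range(entity_num), 2)]
    if len(pairs) > h_t_limit:
        raise ValueError("h_t_limit too small for this document")
    # 1.0 marks a real pair slot, 0.0 marks padding (cf. relation_mask).
    mask = [1.0] * len(pairs) + [0.0] * (h_t_limit - len(pairs))
    return pairs, mask

# A document with 3 entity clusters yields 3 * 2 = 6 ordered pairs;
# the remaining h_t_limit - 6 slots are padding.
pairs, mask = build_pair_inputs(entity_num=3, h_t_limit=10)
```

The same padded pair axis is what relation_mask zeroes out, so losses and predictions ignore the padding slots.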