"""Functionality for postprocessing :class:`~sgnlp.models.lsr.modeling.LsrModelOutput`"""
import json
import numpy as np
from torch import sigmoid
from .utils import idx2ht


class LsrPostprocessor:
"""This class processes :class:`~sgnlp.models.lsr.modeling.LsrModelOutput` to a readable format.
Args:
rel2id (:obj:`dict`):
Relation to id mapping.
rel_info (:obj:`dict`):
Relation to description mapping.
        pred_threshold (:obj:`float`, `optional`, defaults to 0.3):
            Threshold above which a predicted relation is returned.
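
    Example:
        A minimal usage sketch; the mapping values below are illustrative::

            rel2id = {"Na": 0, "P17": 1}
            rel_info = {"Na": "No relation", "P17": "country"}
            postprocessor = LsrPostprocessor(rel2id, rel_info, pred_threshold=0.3)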
"""
def __init__(self, rel2id: dict, rel_info: dict, pred_threshold: float = 0.3):
self.rel2id = rel2id
self.id2rel = {v: k for k, v in self.rel2id.items()}
self.rel_info = rel_info
self.pred_threshold = pred_threshold

    @staticmethod
def from_file_paths(
rel2id_path: str, rel_info_path: str, pred_threshold: float = 0.3
):
"""Constructs LsrPostprocessor from relevant DocRED files.
Args:
rel2id_path (:obj:`str`):
Path to relation to id mapping file.
rel_info_path (:obj:`str`):
Path to relation info file. This is a mapping from relation to relation description.
            pred_threshold (:obj:`float`, `optional`, defaults to 0.3):
                Threshold above which a predicted relation is returned.

        Returns:
            postprocessor (:class:`~sgnlp.models.lsr.postprocess.LsrPostprocessor`)
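
        Example:
            A sketch assuming DocRED's ``rel2id.json`` and ``rel_info.json`` files are available locally::

                postprocessor = LsrPostprocessor.from_file_paths(
                    "rel2id.json", "rel_info.json", pred_threshold=0.3
                )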
"""
        with open(rel2id_path) as f:
            rel2id = json.load(f)
        with open(rel_info_path) as f:
            rel_info = json.load(f)
        rel_info["Na"] = "No relation"  # Add a description for the no-relation class
        return LsrPostprocessor(rel2id, rel_info, pred_threshold)

    def __call__(self, prediction, data):
"""
Args:
prediction (:class:`torch.FloatTensor`):
Prediction of :class:`~sgnlp.models.lsr.modeling.LsrModelOutput`.
data:
DocRED-like data that was used as input in preprocessing step.
Returns:
List of dictionary that includes the document, the entity clusters found in the document, and
the predicted relations between the entity clusters.
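
        Example:
            A sketch of the output for one instance (values are illustrative)::

                [{
                    "document": ["Singapore", "is", "a", "city-state", "."],
                    "clusters": [[[0, 1]], [[3, 4]]],
                    "relations": [{"score": 0.81, "relation": "country",
                                   "object_idx": 0, "subject_idx": 1}],
                }]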
"""
output = []
for prediction_instance, data_instance in zip(prediction, data):
            document = [
                item for sublist in data_instance["sents"] for item in sublist
            ]  # Flatten the nested sentence lists into a single list of tokens
num_entities = len(data_instance["vertexSet"])
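            # Relations are directed, so there are num_entities * (num_entities - 1)
            # ordered head-tail entity pairs to score.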
total_relation_combinations = num_entities * (num_entities - 1)
            pred = sigmoid(prediction_instance).detach().cpu().numpy()
            pred = pred[:total_relation_combinations]  # Keep only this instance's valid entity pairs
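            # np.where returns one index array per axis; zipping them yields
            # (pair_idx, relation_idx) pairs that score above the threshold.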
above_threshold_indices = zip(*np.where(pred > self.pred_threshold))
relations = []
for h_t_idx, rel_idx in above_threshold_indices:
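                # idx2ht decodes the flat pair index back into head and tail entity indices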
h_idx, t_idx = idx2ht(h_t_idx, num_entities)
rel_description = self.rel_info[self.id2rel[rel_idx]]
                # Typecast to native Python types so the output is JSON serializable
                # (NumPy types are generally not JSON serializable by default)
relations.append(
{
"score": float(pred[h_t_idx, rel_idx]),
"relation": rel_description,
"object_idx": int(h_idx),
"subject_idx": int(t_idx),
}
)
# Compute sentence start indices
sentence_start_idx = [0]
sentence_start_idx_counter = 0
for sent in data_instance["sents"]:
sentence_start_idx_counter += len(sent)
sentence_start_idx.append(sentence_start_idx_counter)
clusters = []
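            # Each vertex set holds every mention of one entity; shift each mention's
            # position from sentence-local to document-level token offsets.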
for vertex_set in data_instance["vertexSet"]:
cluster = []
for entity in vertex_set:
sent_id = entity["sent_id"] # sent_id that entity appears in
pos_adjustment = sentence_start_idx[sent_id] # start idx of sent
pos = list(entity["pos"])
pos = [
pos[0] + pos_adjustment,
pos[1] + pos_adjustment,
] # adjust pos by adding start of sentence idx
cluster.append(pos)
clusters.append(cluster)
output.append(
{"clusters": clusters, "document": document, "relations": relations}
)
return output