Source code for sgnlp.models.span_extraction.postprocess
from typing import Any, Dict, List, Tuple, Union

from transformers.data.processors.squad import SquadExample, SquadFeatures
from transformers.modeling_outputs import QuestionAnsweringModelOutput

from .utils import (
    RawResult,
    get_best_predictions,
)
class RecconSpanExtractionPostprocessor:
    """Postprocessor for the RecconSpanExtraction model.

    Turns the raw output of RecconSpanExtractionModel into causal spans and
    their probabilities.

    Args:
        threshold (float, optional): probability threshold value to extract
            causal span. Defaults to 0.7.
    """

    # Added to each prediction's position in the batch to build its
    # ``unique_id``.
    # NOTE(review): presumably matches the unique_id numbering of the features
    # produced by the preprocessor -- confirm against it.
    _UNIQUE_ID_OFFSET = 1000000000

    def __init__(self, threshold: float = 0.7) -> None:
        self.threshold = threshold

    def __call__(
        self,
        raw_pred: QuestionAnsweringModelOutput,
        evidences: List[Dict[str, Union[int, str]]],
        examples: List[SquadExample],
        features: List[SquadFeatures],
    ) -> Tuple[List[List[str]], List[List[int]], List[List[Union[int, float]]]]:
        """Convert raw prediction (logits) to spans, causal flags and probabilities.

        Args:
            raw_pred (QuestionAnsweringModelOutput): output of
                RecconSpanExtractionModel.
            evidences (List[Dict[str, Union[int, str]]]): list of evidence
                utterances; each entry's ``"evidence"`` value is read.
            examples (List[SquadExample]): list of SquadExample instances -
                output of RecconSpanExtractionPreprocessor.
            features (List[SquadFeatures]): list of SquadFeatures instances -
                output of RecconSpanExtractionPreprocessor.

        Returns:
            Tuple[List[List[str]], List[List[int]], List[List[Union[int, float]]]]:
                1. List of list of spans.
                2. List of list of integers indicating whether the
                   corresponding span is causal (1) or not (0).
                3. List of list of int/float giving the probability of the
                   corresponding span being causal; -1 indicates the span is
                   not causal.
        """
        all_results = self._process_raw_pred(raw_pred)
        answers = get_best_predictions(
            all_examples=examples,
            all_features=features,
            all_results=all_results,
            n_best_size=20,
            max_answer_length=200,
            do_lower_case=False,
            verbose_logging=False,
            version_2_with_negative=True,
            null_score_diff_threshold=False,
        )
        return self._process_answers(answers, evidences)

    def _process_raw_pred(
        self, raw_pred: QuestionAnsweringModelOutput
    ) -> List[RawResult]:
        """Process raw output into RawResult objects for get_best_predictions().

        Args:
            raw_pred (QuestionAnsweringModelOutput): output of
                RecconSpanExtractionModel; its ``"start_logits"`` and
                ``"end_logits"`` entries are read row by row.

        Returns:
            List[RawResult]: one RawResult per row of ``start_logits``, which
                can be used with get_best_predictions().
        """
        end_logits = raw_pred["end_logits"]
        return [
            RawResult(
                unique_id=self._UNIQUE_ID_OFFSET + idx,
                start_logits=start.tolist(),
                end_logits=end_logits[idx].tolist(),
            )
            for idx, start in enumerate(raw_pred["start_logits"])
        ]

    def _process_answers(
        self, answers: List[Dict[str, Any]], evidences: List[Dict[str, Union[int, str]]]
    ) -> Tuple[List[List[str]], List[List[int]], List[List[Union[int, float]]]]:
        """Post process predictions generated from get_best_predictions().

        Args:
            answers (List[Dict[str, Any]]): list of predictions from
                get_best_predictions(); the first element of each entry's
                ``"answer"`` and ``"probability"`` lists is used.
            evidences (List[Dict[str, Union[int, str]]]): list of evidence
                utterances, paired with ``answers`` by position.

        Returns:
            Tuple[List[List[str]], List[List[int]], List[List[Union[int, float]]]]:
                per-evidence lists of the context pieces, the causal flags
                (1 causal, 0 not causal) and the probabilities (-1 for pieces
                that are not causal).
        """
        context = []
        evidence_span = []
        probability = []
        for ans, evid in zip(answers, evidences):
            evidence = evid["evidence"]
            span = ans["answer"][0]
            prob = ans["probability"][0]
            below_threshold = prob < self.threshold

            if not span or below_threshold:
                # No usable prediction: keep the evidence whole, not causal.
                pieces, flags, probs = [evidence], [0], [-1]
            elif span == evidence:
                # The whole evidence is the causal span.
                pieces, flags, probs = [evidence], [1], [prob]
            else:
                pieces, flags, probs = self._process_span(evidence, span, prob)

            context.append(pieces)
            evidence_span.append(flags)
            probability.append(probs)
        return context, evidence_span, probability

    def _process_span(
        self, evidence: str, span: str, prob: float
    ) -> Tuple[List[str], List[int], List[Union[int, float]]]:
        """Split the evidence string around a causal span contained in it.

        The returned pieces always concatenate back to ``evidence``, even when
        ``span`` occurs in it more than once.

        Args:
            evidence (str): the full evidence string.
            span (str): evidence span, expected to be a substring of the full
                evidence string.
            prob (float): probability score for the evidence span.

        Returns:
            Tuple[List[str], List[int], List[Union[int, float]]]: the pieces
                of the evidence, a 1/0 flag per piece (1 for the span itself)
                and a probability per piece (``prob`` for the span, -1
                otherwise). If ``span`` is empty or does not occur in
                ``evidence``, the evidence is returned as a single piece
                flagged 0 with probability -1.
        """
        if not span or span not in evidence:
            # Nothing in the evidence can be marked, so report it as not
            # causal rather than failing on a missing split part.
            return [evidence], [0], [-1]

        # Slice instead of str.split(): split() breaks the text at every
        # occurrence of the span, so repeated occurrences would drop text.
        if evidence.startswith(span):
            return [span, evidence[len(span):]], [1, 0], [prob, -1]
        if evidence.endswith(span):
            return [evidence[: -len(span)], span], [0, 1], [-1, prob]

        # Span is strictly inside: cut around its first occurrence.
        before, _, after = evidence.partition(span)
        return [before, span, after], [0, 1, 0], [-1, prob, -1]