Source code for sgnlp.models.span_extraction.preprocess
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import PreTrainedTokenizer, BatchEncoding
from transformers.data.processors.squad import SquadFeatures, SquadExample

from .tokenization import RecconSpanExtractionTokenizer
from .utils import load_examples
class RecconSpanExtractionPreprocessor:
    """Class to initialise the Preprocessor for the RecconSpanExtraction model.
    Preprocesses inputs and tokenizes them so they can be used with RecconSpanExtractionModel.

    Args:
        tokenizer (Optional[PreTrainedTokenizer], optional): Tokenizer to use for the preprocessor.
            Defaults to None, in which case the default RecconSpanExtractionTokenizer is loaded.
    """

    def __init__(
        self,
        tokenizer: Optional[PreTrainedTokenizer] = None,
    ):
        if tokenizer is None:
            # Fall back to the default SpanBERT tokenizer checkpoint
            self.tokenizer = RecconSpanExtractionTokenizer.from_pretrained(
                "mrm8488/spanbert-finetuned-squadv2"
            )
        else:
            self.tokenizer = tokenizer
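
    # Minimal usage sketch (assumes the default tokenizer checkpoint
    # "mrm8488/spanbert-finetuned-squadv2" is downloadable; the batch values
    # below are made up for illustration):
    #
    #   preprocessor = RecconSpanExtractionPreprocessor()
    #   batch = {
    #       "emotion": ["happiness"],
    #       "target_utterance": ["That is wonderful news !"],
    #       "evidence_utterance": ["I got the job !"],
    #       "conversation_history": ["I got the job ! That is wonderful news !"],
    #   }
    #   tokenized, evidences, examples, features = preprocessor(batch)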

    def __call__(
        self, data_batch: Dict[str, List[str]]
    ) -> Tuple[
        BatchEncoding,
        List[Dict[str, Union[int, str]]],
        List[SquadExample],
        List[SquadFeatures],
    ]:
        """Preprocess data and tokenize it, so it can be used in RecconSpanExtractionModel.

        Args:
            data_batch (Dict[str, List[str]]): The dictionary should contain the following keys:
                'emotion', 'target_utterance', 'evidence_utterance' and 'conversation_history'.
                Each value should be a list of strings, with all lists being of the same length.

        Returns:
            Tuple[BatchEncoding, List[Dict[str, Union[int, str]]], List[SquadExample], List[SquadFeatures]]:
                1. BatchEncoding output from the tokenizer
                2. List of evidence utterances
                3. List of SquadExample output from the load_examples() function
                4. List of SquadFeatures output from the load_examples() function
        """
        self._check_values_len(data_batch)
        concatenated_batch, evidences = self._concatenate_batch(data_batch)
        dataset, examples, features = load_examples(
            concatenated_batch, self.tokenizer, evaluate=True, output_examples=True
        )

        # Each dataset instance holds (input_ids, attention_mask, token_type_ids, ...);
        # unsqueeze adds a batch dimension so the tensors can be concatenated below.
        input_ids = [torch.unsqueeze(instance[0], 0) for instance in dataset]
        attention_mask = [torch.unsqueeze(instance[1], 0) for instance in dataset]
        token_type_ids = [torch.unsqueeze(instance[2], 0) for instance in dataset]

        output = {
            "input_ids": torch.cat(input_ids, dim=0),
            "attention_mask": torch.cat(attention_mask, dim=0),
            "token_type_ids": torch.cat(token_type_ids, dim=0),
        }
        output = BatchEncoding(output)

        return output, evidences, examples, features
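
    # The returned BatchEncoding can be passed to the model via keyword
    # expansion (a sketch with hypothetical variable names, assuming a
    # RecconSpanExtractionModel instance named `model`):
    #
    #   output, evidences, examples, features = preprocessor(batch)
    #   model_output = model(**output)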

    def _concatenate_batch(
        self, data_batch: Dict[str, List[str]]
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Takes in a data batch and converts it into a list of SQuAD-style dictionaries
        which can be used with the tokenizer.

        Args:
            data_batch (Dict[str, List[str]]): The dictionary should look like this:
                {'emotion': ['happiness'],
                 'target_utterance': ['......'],
                 'evidence_utterance': ['......'],
                 'conversation_history': ['......']}
                The length of each value must be the same.

        Returns:
            Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
                1. list of SQuAD-style dictionaries, one per instance
                2. list of evidence utterances, one per instance
        """
        concatenated_batch = []
        evidences_batch = []
        emotion_batch = data_batch["emotion"]
        target_utterance_batch = data_batch["target_utterance"]
        evidence_utterance_batch = data_batch["evidence_utterance"]
        conversation_history_batch = data_batch["conversation_history"]

        for i, (
            emotion,
            target_utterance,
            evidence_utterance,
            conversation_history,
        ) in enumerate(
            zip(
                emotion_batch,
                target_utterance_batch,
                evidence_utterance_batch,
                conversation_history_batch,
            )
        ):
            # Build the SQuAD-style question. Note that the segments are joined
            # without separating punctuation, presumably to match the question
            # format the model was fine-tuned on.
            concatenated_qns = (
                "The target utterance is "
                + target_utterance
                + "The evidence utterance is "
                + evidence_utterance
                + "What is the causal span from context that is relevant to the target utterance's emotion "
                + emotion
                + " ?"
            )
            # Placeholder answer: load_examples() is called with evaluate=True,
            # so the answer text is never used.
            inputs = {
                "id": i,
                "question": concatenated_qns,
                "answers": [{"text": " ", "answer_start": 0}],
                "is_impossible": False,
            }
            instance_dict = {"context": conversation_history, "qas": [inputs]}
            concatenated_batch.append(instance_dict)

            evidence = {"id": i, "evidence": evidence_utterance}
            evidences_batch.append(evidence)

        return concatenated_batch, evidences_batch
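
    # For a single-instance batch, each entry of `concatenated_batch` is a
    # SQuAD-style dict along these lines (values abbreviated, illustrative only):
    #
    #   {
    #       "context": "<conversation_history>",
    #       "qas": [{
    #           "id": 0,
    #           "question": "The target utterance is ...What is the causal span ... ?",
    #           "answers": [{"text": " ", "answer_start": 0}],
    #           "is_impossible": False,
    #       }],
    #   }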

    def _concatenate_instance(
        self,
        emotion: str,
        target_utterance: str,
        evidence_utterance: str,
        conversation_history: str,
    ) -> str:
        """Concatenate a single instance into a single string.
        This helper is not called by __call__ in this module.

        Args:
            emotion (str): emotion of the instance
            target_utterance (str): target utterance of the instance
            evidence_utterance (str): evidence utterance of the instance
            conversation_history (str): conversation history of the instance

        Returns:
            str: concatenated string of the single instance
        """
        concatenated_text = (
            " "
            + emotion
            + " <SEP> "
            + target_utterance
            + " <SEP> "
            + evidence_utterance
            + " <SEP> "
            + conversation_history
        )
        return concatenated_text
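
    # Illustrative output (made-up arguments):
    #   _concatenate_instance("happiness", "t", "e", "h")
    #   -> " happiness <SEP> t <SEP> e <SEP> h"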

    def _check_values_len(self, data_batch: Dict[str, List[str]]):
        """Check that the lengths of all values in the dict are the same.

        Args:
            data_batch (Dict[str, List[str]]): data_batch input from the __call__ method
        """
        values_len = [len(v) for v in data_batch.values()]
        unique_len = len(set(values_len))
        assert unique_len == 1, "Lengths of values are not consistent across keys"
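
    # Illustrative failure case (made-up batch): mismatched value lengths such as
    #   {"emotion": ["happiness", "anger"], "target_utterance": ["..."]}
    # fail the assertion above with an AssertionError.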