Source code for sgnlp.models.emotion_entailment.preprocess

from typing import List, Dict, Optional

from transformers import PreTrainedTokenizer, BatchEncoding

from .tokenization import RecconEmotionEntailmentTokenizer


class RecconEmotionEntailmentPreprocessor:
    """Class to initialise the Preprocessor for RecconEmotionEntailment model.

    Preprocesses inputs and tokenises them so they can be used with
    RecconEmotionEntailmentModel.

    Args:
        tokenizer (Optional[PreTrainedTokenizer], optional): Tokenizer to use
            for preprocessor. When None, a RecconEmotionEntailmentTokenizer is
            loaded via ``from_pretrained("roberta-base")``. Defaults to None.
        max_length (int, optional): maximum length of truncated tokens.
            Defaults to 512.
    """

    def __init__(
        self,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        max_length: int = 512,
    ):
        self.max_length = max_length
        if tokenizer is None:
            self.tokenizer = RecconEmotionEntailmentTokenizer.from_pretrained(
                "roberta-base"
            )
        else:
            self.tokenizer = tokenizer

    def __call__(self, data_batch: Dict[str, List[str]]) -> BatchEncoding:
        """Preprocess data then tokenize, so it can be used in
        RecconEmotionEntailmentModel.

        Args:
            data_batch (Dict[str, List[str]]): The dictionary should contain
                the following keys 'emotion', 'target_utterance',
                'evidence_utterance', and 'conversation_history'. Each value
                should be a list of strings, with each list being of same
                length.

        Returns:
            BatchEncoding: BatchEncoding instance returned from self.tokenizer

        Raises:
            AssertionError: If the lists in ``data_batch`` differ in length.
            KeyError: If one of the four required keys is missing.
        """
        # Validate first: zip() in _concatenate_batch would otherwise silently
        # truncate to the shortest list.
        self._check_values_len(data_batch)
        concatenated_batch = self._concatenate_batch(data_batch)
        output = self.tokenizer(
            concatenated_batch,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
            return_tensors="pt",
        )
        return output

    def _concatenate_batch(self, data_batch: Dict[str, List[str]]) -> List[str]:
        """Takes in data batch and converts them into a list of string which
        can be used with the tokenizer.

        Args:
            data_batch (Dict[str, List[str]]): The dictionary should look like
                this:
                {'emotion': ['happiness'],
                'target_utterance': ['......'],
                'evidence_utterance': ['......'],
                'conversation_history': ['......']}
                The length of each value must be the same

        Returns:
            List[str]: list of concatenated string for each instance
        """
        return [
            self._concatenate_instance(
                emotion, target_utterance, evidence_utterance, conversation_history
            )
            for emotion, target_utterance, evidence_utterance, conversation_history in zip(
                data_batch["emotion"],
                data_batch["target_utterance"],
                data_batch["evidence_utterance"],
                data_batch["conversation_history"],
            )
        ]

    def _concatenate_instance(
        self,
        emotion: str,
        target_utterance: str,
        evidence_utterance: str,
        conversation_history: str,
    ) -> str:
        """Concatenate a single instance into a single string.

        Args:
            emotion (str): emotion of instance
            target_utterance (str): target_utterance of instance
            evidence_utterance (str): evidence utterance of instance
            conversation_history (str): conversation history of instance

        Returns:
            str: concatenated string of a single instance, in the form
                " <emotion> <SEP> <target> <SEP> <evidence> <SEP> <history>"
        """
        # The leading space is part of the original output format and is kept
        # so the tokenizer sees byte-identical input.
        return " " + " <SEP> ".join(
            (emotion, target_utterance, evidence_utterance, conversation_history)
        )

    def _check_values_len(self, data_batch: Dict[str, List[str]]) -> None:
        """Check if the length of all values in the Dict are the same.

        Args:
            data_batch (Dict[str, List[str]]): data_batch input from __call__
                method

        Raises:
            AssertionError: If the values do not all share one length (an
                empty dict also fails, as it has no length at all).
        """
        unique_lengths = {len(values) for values in data_batch.values()}
        if len(unique_lengths) != 1:
            # Raised explicitly rather than via `assert` so the validation is
            # not stripped when Python runs with -O; AssertionError is kept as
            # the exception type for backward compatibility with callers.
            lengths_by_key = {key: len(values) for key, values in data_batch.items()}
            raise AssertionError(
                "Length of values are not consistent across keys: "
                f"{lengths_by_key}"
            )