Source code for sgnlp.models.sentic_gcn.postprocess

from typing import Dict, List, Union

import torch.nn.functional as F

from .preprocess import SenticGCNData, SenticGCNBertData
from .modeling import SenticGCNModelOutput, SenticGCNBertModelOutput


class SenticGCNBasePostprocessor:
    """
    Base postprocessor class providing common post processing functions.

    Converts model logits into predicted labels and groups the per-aspect predictions
    back into one entry per input sentence.

    Args:
        return_full_text (bool): Flag to indicate if the full text should be included in the output.
        return_aspects_text (bool): Flag to indicate if the list of aspects text should be included in the output.
    """

    def __init__(self, return_full_text: bool = False, return_aspects_text: bool = False) -> None:
        self.return_full_text = return_full_text
        self.return_aspects_text = return_aspects_text

    def __call__(
        self,
        processed_inputs: List[Union[SenticGCNData, SenticGCNBertData]],
        model_outputs: Union[SenticGCNModelOutput, SenticGCNBertModelOutput],
    ) -> List[Dict[str, Union[List[str], List[int], float]]]:
        """
        Convert model logits into per-sentence prediction dictionaries.

        Each element of ``processed_inputs`` is paired, in order, with one row of
        ``model_outputs.logits`` (extra items on either side are ignored, as with ``zip``).
        Inputs with equal ``full_text_tokens`` are merged into a single output entry.
        Entries are returned in the order in which each sentence is first seen.

        Args:
            processed_inputs: Preprocessed inputs, one per (sentence, aspect) pair.
            model_outputs: Model output exposing a ``logits`` tensor whose rows align
                with ``processed_inputs``.

        Returns:
            A list of dicts, one per distinct sentence, with keys:
              - "sentence": the ``full_text_tokens`` of the sentence.
              - "aspects": list of ``aspect_token_indexes``, one per aspect.
              - "labels": list of int predictions, one per aspect (argmax class index minus 1).
              - "full_text": the input's ``full_text`` (only if ``return_full_text``).
              - "aspects_text": list of ``aspect`` values (only if ``return_aspects_text``).
        """
        # Move to CPU before converting: Tensor.numpy() raises for tensors on an accelerator.
        probabilities = F.softmax(model_outputs.logits, dim=-1).detach().cpu().numpy()
        # A single vectorised argmax over the whole batch. The -1 offset presumably maps
        # class indices 0/1/2 to polarity labels -1/0/1.
        # NOTE(review): confirm against the training label encoding.
        predictions = probabilities.argmax(axis=-1) - 1

        outputs = []
        # Maps each sentence already seen to its (shared, mutable) entry in ``outputs``,
        # so grouping is a dict lookup rather than a scan over all previous entries.
        # Token lists are unhashable, so the key is a tuple; this assumes the individual
        # tokens are hashable (e.g. str).
        entries_by_sentence = {}
        for processed_input, prediction in zip(processed_inputs, predictions):
            label = int(prediction)
            key = tuple(processed_input.full_text_tokens)

            entry = entries_by_sentence.get(key)
            if entry is not None:
                # Sentence already seen: add this aspect's data to the existing entry.
                entry["aspects"].append(processed_input.aspect_token_indexes)
                entry["labels"].append(label)
                if self.return_aspects_text:
                    entry["aspects_text"].append(processed_input.aspect)
                continue

            entry = {
                "sentence": processed_input.full_text_tokens,
                "aspects": [processed_input.aspect_token_indexes],
                "labels": [label],
            }
            if self.return_full_text:
                entry["full_text"] = processed_input.full_text
            if self.return_aspects_text:
                entry["aspects_text"] = [processed_input.aspect]
            entries_by_sentence[key] = entry
            outputs.append(entry)
        return outputs
class SenticGCNPostprocessor(SenticGCNBasePostprocessor):
    """
    Postprocessor paired with SenticGCNModel.

    Produces, for the model's outputs, a list of input text tokens, aspect token indexes
    and prediction labels. This class adds no behavior of its own; configuration is
    forwarded unchanged to SenticGCNBasePostprocessor.

    Args:
        return_full_text (bool): Flag to indicate if the full text should be included in the output.
        return_aspects_text (bool): Flag to indicate if the list of aspects text should be included in the output.
    """

    def __init__(self, return_full_text: bool = False, return_aspects_text: bool = False) -> None:
        # Forward both flags positionally, in the base initializer's parameter order.
        super().__init__(return_full_text, return_aspects_text)
class SenticGCNBertPostprocessor(SenticGCNBasePostprocessor):
    """
    Postprocessor paired with SenticGCNBertModel.

    Produces, for the model's outputs, a list of input text tokens, aspect token indexes
    and prediction labels. This class adds no behavior of its own; configuration is
    forwarded unchanged to SenticGCNBasePostprocessor.

    Args:
        return_full_text (bool): Flag to indicate if the full text should be included in the output.
        return_aspects_text (bool): Flag to indicate if the list of aspects text should be included in the output.
    """

    def __init__(self, return_full_text: bool = False, return_aspects_text: bool = False) -> None:
        # Forward both flags positionally, in the base initializer's parameter order.
        super().__init__(return_full_text, return_aspects_text)