from typing import Dict, List
import torch
from transformers import PreTrainedTokenizer, PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding
from .config import UFDEmbeddingConfig
from .tokenization import UFDTokenizer
from .modeling import UFDEmbeddingModel
class UFDPreprocessor:
    """
    Preprocess raw text into a batch of tensors for the UFDModel to predict on.

    Inject tokenizer and/or embedding model instances via the ``tokenizer`` and
    ``embedding_model`` input args, or pass in the tokenizer name and/or
    embedding model name via the ``tokenizer_name`` and ``embedding_model_name``
    input args to create them with ``from_pretrained``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer = None,
        embedding_model: PreTrainedModel = None,
        tokenizer_name: str = "xlm-roberta-large",
        embedding_model_name: str = "xlm-roberta-large",
        device: torch.device = torch.device("cpu"),
    ):
        """
        Args:
            tokenizer (PreTrainedTokenizer, optional): pre-built tokenizer instance;
                takes precedence over ``tokenizer_name`` when provided.
            embedding_model (PreTrainedModel, optional): pre-built embedding model;
                takes precedence over ``embedding_model_name`` when provided.
            tokenizer_name (str): model hub name used to load a UFDTokenizer
                when no ``tokenizer`` instance is injected.
            embedding_model_name (str): model hub name used to load the embedding
                config and model when no ``embedding_model`` instance is injected.
            device (torch.device): device the embedding model (and token tensors)
                are moved to. Defaults to CPU.
        """
        self.device = device
        # Prefer an injected tokenizer; otherwise load one by name.
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = UFDTokenizer.from_pretrained(tokenizer_name)
        # Prefer an injected embedding model; otherwise load config + weights
        # by name and move the model to the target device. NOTE(review): an
        # injected embedding_model is NOT moved to `device` — callers must
        # place it themselves.
        if embedding_model is not None:
            self.embedding_model = embedding_model
        else:
            embedding_config = UFDEmbeddingConfig.from_pretrained(embedding_model_name)
            self.embedding_model = UFDEmbeddingModel.from_pretrained(
                embedding_model_name, config=embedding_config
            ).to(device)

    def __call__(self, data_batch: List[str]) -> BatchEncoding:
        """
        Main method to start preprocessing.

        Args:
            data_batch (List[str]): list of raw text to process.

        Returns:
            BatchEncoding: a BatchEncoding instance with key 'data_batch'
            mapped to the mean-pooled embeddings of the input texts.
        """
        text_embeddings = self._get_embedding(data_batch)
        # Mean over the sequence dimension of the embedding model's first
        # output (the last hidden state) to get one vector per input text.
        mean_features = torch.mean(text_embeddings[0], dim=1)
        return BatchEncoding({"data_batch": mean_features})

    def _get_embedding(self, data_batch: List[str]) -> Dict[str, BatchEncoding]:
        """
        Method to generate embedding tensors from a list of texts.

        Args:
            data_batch (List[str]): list of input text.

        Returns:
            Dict[str, BatchEncoding]: model output containing the embeddings
            generated from the input text. NOTE(review): the caller indexes
            this with ``[0]``, so the actual runtime value looks like a
            tuple/ModelOutput rather than a plain dict — confirm and tighten
            the annotation against UFDEmbeddingModel's forward signature.
        """
        # Inference only: disable dropout/batchnorm updates and gradient tracking.
        self.embedding_model.eval()
        with torch.no_grad():
            # NOTE(review): assumes the UFDTokenizer returns tensors (so that
            # .to(device) is meaningful) and handles truncation internally —
            # verify against UFDTokenizer.__call__.
            tokens = self.tokenizer(data_batch, padding=True).to(self.device)
            text_embedding = self.embedding_model(**tokens)
        return text_embedding