Source code for sgnlp.models.csgec.utils

from math import inf
import os
import pathlib
import urllib
import requests


def download_tokenizer_files(base_url: str, local_path: str) -> None:
    """Download all required files for tokenizer from online storage.

    Every tokenizer file name is resolved against ``base_url`` with
    ``urllib.parse.urljoin`` and handed to ``download_url_file``.

    Args:
        base_url (str): Base url of storage location. It should end with a
            trailing "/"; otherwise ``urljoin`` replaces the last path
            segment of the url instead of appending the file name to it.
        local_path (str): Path to the folder on the local machine.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.parse`` has been loaded.
    from urllib.parse import urljoin

    tokenizer_files = (
        "special_tokens_map.json",
        "tokenizer_config.json",
        "vocab.json",
        "merges.txt",
    )
    for file_name in tokenizer_files:
        download_url_file(urljoin(base_url, file_name), local_path)
def download_url_file(url: str, save_folder: str) -> None:
    """Helper method to download url file.

    The file is stored in ``save_folder`` under the name given by the text
    after the last "/" of ``url``. The folder is created when missing.
    A response whose status code is not OK is skipped silently, so no file
    is written in that case.

    Args:
        url (str): url file address string.
        save_folder (str): local folder name to save downloaded files.
    """
    os.makedirs(save_folder, exist_ok=True)

    # Everything after the last "/" (the whole url when it has no "/").
    file_name = url.rsplit("/", 1)[-1]
    save_file_name = pathlib.Path(save_folder).joinpath(file_name)

    # Stream the body so it is never held in memory in full, and use the
    # response as a context manager so the connection is always released.
    with requests.get(url, stream=True) as response:
        if response.status_code != requests.codes.ok:
            return
        with open(save_file_name, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
class Buffer:
    """List-backed container that holds at most ``max_len`` elements."""

    def __init__(self, max_len):
        # Start empty; ``max_len`` is the capacity enforced by ``add_element``.
        self.elements = []
        self.max_len = max_len

    def __len__(self):
        """Return the number of stored elements."""
        return len(self.elements)

    def __repr__(self):
        """Return the string form of the underlying list."""
        return str(self.elements)

    def get_current_len(self):
        """Return the number of stored elements."""
        return len(self.elements)

    def get_first_element(self):
        """Remove and return the oldest (first) element."""
        return self.elements.pop(0)

    def get_element(self, idx):
        """Remove and return the element at position ``idx``."""
        return self.elements.pop(idx)

    def add_element(self, element):
        """Append ``element``; asserts that the buffer is not already full."""
        has_room = self.get_current_len() < self.max_len
        assert has_room, "Exceeded max buffer length."
        self.elements.append(element)
class Beam:
    """Fixed-capacity collection of the highest-scoring candidates.

    Each candidate is a dict ``{"score": <number>, "indices": <list of int>}``.
    ``elements`` is kept sorted by descending score, so the best candidate is
    first and the lowest-scoring one is last.
    """

    def __init__(self, beam_size):
        """Create an empty beam holding at most ``beam_size`` elements."""
        self.beam_size = beam_size
        self.elements = []

    def add_element(self, score, indices):
        """Offer one candidate to the beam.

        The candidate is stored when the beam still has room. When the beam
        is full it replaces the lowest-scoring element only if its score is
        strictly higher; otherwise the beam is left unchanged.

        Args:
            score (float or int): score of the candidate.
            indices (list of int): indices associated with the candidate.

        Raises:
            AssertionError: if ``score`` is not a float/int, ``indices`` is
                not a list, or ``indices`` contains a non-integer.
        """
        # A single isinstance call with a tuple of types. The previous check
        # asserted a one-element tuple, which is always truthy, so invalid
        # scores were never rejected.
        assert isinstance(score, (float, int)), "score should be a float or integer"
        assert isinstance(indices, list), "indices should be a list"
        assert all(
            isinstance(x, int) for x in indices
        ), "elements in indices should be integers"

        new_element = {"score": score, "indices": indices}

        # The number of elements should be at most equal to the beam size.
        if len(self.elements) < self.beam_size:
            self._add_element(new_element)
        elif new_element["score"] > self.get_lowest_score():
            # Beam is full: evict the lowest-scoring element to make room
            # for the strictly better new one.
            self._remove_last_element()
            self._add_element(new_element)

    def add_elements(self, scores_lst, indices_lst):
        """Offer several candidates, pairing scores and indices positionally."""
        for score, element in zip(scores_lst, indices_lst):
            self.add_element(score, element)

    def get_elements(self):
        """Return the stored elements, ordered from highest to lowest score."""
        return self.elements

    def get_best_element(self):
        """Return the highest-scoring element (IndexError if the beam is empty)."""
        return self.elements[0]

    def _order_elements(self):
        """Re-sort the stored elements by descending score."""
        self.elements = sorted(self.elements, key=lambda x: x["score"], reverse=True)

    def _add_element(self, element):
        """Append ``element`` and restore the descending-score order."""
        assert len(self.elements) < self.beam_size, "Beam is full"
        assert isinstance(element, dict), "element should be a dictionary object"
        self.elements.append(element)
        # This step is necessary to ensure the last element is always the
        # lowest score.
        self._order_elements()

    def _remove_last_element(self):
        """Drop the last (lowest-scoring) element."""
        self.elements.pop(-1)

    def get_lowest_score(self):
        """Return the lowest stored score, or ``-inf`` for an empty beam."""
        if not self.elements:
            return -inf
        return self.elements[-1]["score"]