import argparse
import json
import logging
import os
import pathlib
import pickle
from typing import Dict, Tuple, List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim import Adam
from .config import (
UFDAdaptorGlobalConfig,
UFDAdaptorDomainConfig,
UFDCombineFeaturesMapConfig,
UFDClassifierConfig,
UFDEmbeddingConfig,
)
from .data_class import UFDArguments
from .tokenization import UFDTokenizer
from .modeling import (
UFDAdaptorGlobalModel,
UFDAdaptorDomainModel,
UFDCombineFeaturesMapModel,
UFDClassifierModel,
UFDEmbeddingModel,
)
# Import-time side effect: configure the root logger at DEBUG level so the
# logging.info() calls made by the helpers in this module are emitted.
# NOTE(review): per the stdlib docs this is a no-op when the root logger already
# has handlers, so an application that configured logging earlier is unaffected.
logging.basicConfig(level=logging.DEBUG)
def create_unsupervised_models(
    cfg: UFDArguments,
) -> Tuple[UFDAdaptorDomainModel, UFDAdaptorGlobalModel, UFDCombineFeaturesMapModel]:
    """Build the freshly initialised model group used for unsupervised training.

    Args:
        cfg (UFDArguments): parsed configuration. ``cfg.device`` and the
            ``in_dim`` / ``dim_hidden`` / ``out_dim`` / ``initrange`` entries of
            ``cfg.train_args`` are read.

    Returns:
        Tuple[UFDAdaptorDomainModel, UFDAdaptorGlobalModel, UFDCombineFeaturesMapModel]:
            the domain adaptor, the global adaptor and the combine-features-map
            model, in that order, each moved to ``cfg.device``.
    """
    target_device = torch.device(cfg.device)
    hyperparams = cfg.train_args

    # Both adaptors are configured from the same four hyper-parameters.
    adaptor_kwargs = {
        "in_dim": hyperparams["in_dim"],
        "dim_hidden": hyperparams["dim_hidden"],
        "out_dim": hyperparams["out_dim"],
        "initrange": hyperparams["initrange"],
    }
    domain_cfg = UFDAdaptorDomainConfig(**adaptor_kwargs)
    global_cfg = UFDAdaptorGlobalConfig(**adaptor_kwargs)
    # The feature map is sized by the adaptors' *input* dimension.
    features_map_cfg = UFDCombineFeaturesMapConfig(
        embed_dim=hyperparams["in_dim"],
        initrange=hyperparams["initrange"],
    )

    domain_model = UFDAdaptorDomainModel(domain_cfg).to(target_device)
    global_model = UFDAdaptorGlobalModel(global_cfg).to(target_device)
    features_map_model = UFDCombineFeaturesMapModel(features_map_cfg).to(
        target_device
    )
    return domain_model, global_model, features_map_model
def load_trained_models(
    cfg: UFDArguments,
    source_domain: str,
    target_language: str,
    target_domain: str,
) -> Tuple[
    UFDAdaptorDomainModel,
    UFDAdaptorGlobalModel,
    UFDCombineFeaturesMapModel,
    UFDClassifierModel,
]:
    """Load the saved configs and weights of all four UFD models for one combination.

    Each model lives in its own directory named
    ``<model_folder>/<source_domain>_<target_language>_<target_domain>_<suffix>/``
    where the suffix is ``adaptor_domain``, ``adaptor_global``, ``maper`` or
    ``classifier``. The config and weight file names come from ``cfg.eval_args``.

    Args:
        cfg (UFDArguments): UFDArguments config load from configuration file
        source_domain (str): name of source domain
        target_language (str): name of target language
        target_domain (str): name of target domain

    Returns:
        Tuple[UFDAdaptorDomainModel, UFDAdaptorGlobalModel, UFDCombineFeaturesMapModel, UFDClassifierModel]:
            the four models with pretrained weights loaded, each moved to
            ``cfg.device``.
    """
    device = torch.device(cfg.device)
    combination_prefix = "".join(
        (
            cfg.model_folder,
            "/",
            source_domain,
            "_",
            target_language,
            "_",
            target_domain,
        )
    )
    config_filename = cfg.eval_args["config_filename"]
    model_filename = cfg.eval_args["model_filename"]

    # (directory suffix, config class, model class), in the order they are returned.
    components = (
        ("_adaptor_domain/", UFDAdaptorDomainConfig, UFDAdaptorDomainModel),
        ("_adaptor_global/", UFDAdaptorGlobalConfig, UFDAdaptorGlobalModel),
        ("_maper/", UFDCombineFeaturesMapConfig, UFDCombineFeaturesMapModel),
        ("_classifier/", UFDClassifierConfig, UFDClassifierModel),
    )

    # Every config is loaded first; the weights are loaded afterwards.
    loaded_configs = [
        config_cls.from_pretrained(combination_prefix + suffix + config_filename)
        for suffix, config_cls, _ in components
    ]
    return tuple(
        model_cls.from_pretrained(
            combination_prefix + suffix + model_filename, config=loaded_config
        ).to(device)
        for (suffix, _, model_cls), loaded_config in zip(components, loaded_configs)
    )
def load_unlabelled(filename: str) -> List:
    """Load an unlabelled dataset for unsupervised training.

    The file is read as UTF-8 regardless of the platform locale, since the
    corpus may contain non-ASCII text.

    Args:
        filename (str): path of a text file with one sample per line

    Returns:
        List: one ``[0, text]`` pair per line of the file, where ``text`` is the
            line with surrounding whitespace stripped. Blank lines are kept as
            ``[0, ""]``.
    """
    with open(filename, "r", encoding="utf-8") as handle:
        # 0 is a placeholder label; it is kept because the research code used it.
        return [[0, line.strip()] for line in handle]
def load_labelled(filename: str) -> List:
    """Load a labelled, tab-separated dataset for supervised training.

    The file is read as UTF-8 regardless of the platform locale. The line
    terminator is removed before splitting so the last field of each record does
    not carry a trailing newline.

    Args:
        filename (str): path of a text file with one tab-separated record per line

    Returns:
        List: one list of string fields per line of the file.
    """
    with open(filename, "r", encoding="utf-8") as handle:
        # Only the newline is stripped: other whitespace may be field content.
        return [line.rstrip("\n").split("\t") for line in handle]
def create_train_embeddings(
    cfg: UFDArguments, tokenizer: UFDTokenizer, model: UFDEmbeddingModel
) -> Dict:
    """Compute the training embeddings for every source domain and the raw corpus.

    Args:
        cfg (UFDArguments): UFDArguments config load from configuration file
        tokenizer (UFDTokenizer): UFD tokenizer class instance
        model (UFDEmbeddingModel): UFD embedding model class instance

    Returns:
        Dict: embeddings keyed by source domain name, plus a ``"raw"`` key
            holding the embeddings of the unlabelled (unsupervised) dataset.
    """
    embeddings_by_domain = {}
    # "raw" is a pseudo-domain, processed last, standing for the unlabelled data.
    for source_domain in cfg.train_args["source_domains"] + ["raw"]:
        if source_domain != "raw":
            labelled_path = f"{cfg.data_folder}/{cfg.train_args['source_language']}/{source_domain}/{cfg.train_args['train_filename']}"
            samples = load_labelled(labelled_path)
        else:
            # Only a single source language is assumed, so the unlabelled
            # corpus sits directly under the data folder.
            unlabelled_path = (
                f"{cfg.data_folder}/{cfg.train_args['unsupervised_dataset_filename']}"
            )
            samples = load_unlabelled(unlabelled_path)
        embeddings_by_domain[source_domain] = extract_embeddings(
            cfg, samples, tokenizer, model
        )
    return embeddings_by_domain
def create_val_test_embeddings(
    cfg: UFDArguments,
    tokenizer: UFDTokenizer,
    model: UFDEmbeddingModel,
    dataset_type: str,
) -> Dict:
    """Compute validation or test embeddings for every target language/domain pair.

    For ``"valid"`` the target languages, target domains and file name are read
    from ``cfg.train_args`` (``val_filename``); for ``"test"`` they are read from
    ``cfg.eval_args`` (``test_filename``). Each dataset is loaded from
    ``<data_folder>/<target_language>/<target_domain>/<filename>``.

    Args:
        cfg (UFDArguments): UFDArguments config load from configuration file
        tokenizer (UFDTokenizer): UFD tokenizer class instance
        model (UFDEmbeddingModel): UFD embedding model class instance
        dataset_type (str): which split to embed, ``"valid"`` or ``"test"``

    Raises:
        ValueError: if ``dataset_type`` is neither ``"valid"`` nor ``"test"``.

    Returns:
        Dict: nested dictionary ``{target_language: {target_domain: embeddings}}``.
    """
    if dataset_type == "valid":
        split_args, filename_key = cfg.train_args, "val_filename"
    elif dataset_type == "test":
        split_args, filename_key = cfg.eval_args, "test_filename"
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            f"Invalid value for dataset_type {dataset_type!r}, "
            "dataset_type should be valid/ test"
        )

    target_languages_list = split_args["target_languages"]
    target_domains_list = split_args["target_domains"]
    filename = split_args[filename_key]

    embeddings_dict = {}
    for target_language in target_languages_list:
        language_embeddings = embeddings_dict[target_language] = {}
        for target_domain in target_domains_list:
            filepath = f"{cfg.data_folder}/{target_language}/{target_domain}/{filename}"
            dataset = load_labelled(filepath)
            language_embeddings[target_domain] = extract_embeddings(
                cfg, dataset, tokenizer, model
            )
    return embeddings_dict
def create_dataset_embedding(cfg: UFDArguments, dataset_type: str) -> Dict:
    """Main helper wrapper function to generate dataset embeddings.

    Loads the embedding config, model and tokenizer named by
    ``cfg.embedding_model_name`` and dispatches to the train or the
    validation/test embedding helper.

    Args:
        cfg (UFDArguments): UFDArguments config load from configuration file
        dataset_type (str): type of dataset to generate embeddings for, one of
            ``"train"``, ``"valid"`` or ``"test"``

    Raises:
        ValueError: for unknown dataset types. ``ValueError`` is a subclass of
            ``Exception``, so callers catching ``Exception`` are unaffected.

    Returns:
        Dict: return dictionary of dataset embeddings.
    """
    # Validate before loading anything: the pretrained embedding model is
    # expensive to load and pointless if the request is going to be rejected.
    if dataset_type not in ("train", "valid", "test"):
        raise ValueError(
            "Invalid value for dataset_type, dataset_type should be train/ valid/ test"
        )

    device = torch.device(cfg.device)
    config = UFDEmbeddingConfig.from_pretrained(cfg.embedding_model_name)
    model = UFDEmbeddingModel.from_pretrained(
        cfg.embedding_model_name, config=config
    ).to(device)
    tokenizer = UFDTokenizer.from_pretrained(cfg.embedding_model_name)

    if dataset_type == "train":
        return create_train_embeddings(cfg, tokenizer, model)
    return create_val_test_embeddings(cfg, tokenizer, model, dataset_type)
def create_classifiers(cfg: UFDArguments) -> Dict:
    """Create one classifier model, criterion and optimizer per source domain.

    All classifiers share a single ``UFDClassifierConfig`` built from
    ``cfg.train_args``; each gets its own cross-entropy criterion and its own
    Adam optimizer using ``cfg.train_args["learning_rate"]``.

    Args:
        cfg (UFDArguments): UFDArguments config load from configuration file

    Returns:
        Dict: ``{source_domain: {"model": ..., "criterion": ..., "optimizer": ...}}``.
    """
    device = torch.device(cfg.device)
    train_args = cfg.train_args
    classifier_config = UFDClassifierConfig(
        embed_dim=train_args["out_dim"],
        num_class=train_args["num_class"],
        initrange=train_args["initrange"],
    )

    classifier_groups = {}
    for source_domain in train_args["source_domains"]:
        domain_model = UFDClassifierModel(classifier_config).to(device)
        domain_criterion = nn.CrossEntropyLoss().to(device)
        domain_optimizer = Adam(
            domain_model.parameters(), lr=train_args["learning_rate"]
        )
        classifier_groups[source_domain] = {
            "model": domain_model,
            "criterion": domain_criterion,
            "optimizer": domain_optimizer,
        }
    return classifier_groups
def get_source2target_domain_mapping(
    source_domains: List[str], target_domains: List[str]
) -> Dict[str, List[str]]:
    """Map every source domain to the target domains that differ from it.

    Args:
        source_domains (List[str]): list of all source domains
        target_domains (List[str]): list of all target domains

    Returns:
        Dict[str, List[str]]: for each source domain, the target domains other
            than itself, in their original order.
    """
    return {
        source: [target for target in target_domains if target != source]
        for source in source_domains
    }
def generate_train_val_dataset(cfg: UFDArguments) -> Tuple[Dict, Dict]:
    """Generate (or load from cache) the train and validation embedding datasets.

    Caching is enabled only when both ``train_cache_filename`` and
    ``val_cache_filename`` are present in ``cfg.train_args``. With caching on,
    each dataset is read from its pickle file inside ``cfg.cache_folder`` when
    that file exists; otherwise it is computed and then written there. With
    caching off, both datasets are always recomputed.

    Args:
        cfg (UFDArguments): UFDArguments config load from configuration file

    Raises:
        AssertionError: if a cached file lacks a source domain, target language
            or target domain requested by ``cfg.train_args`` (stale cache).

    Returns:
        Tuple[Dict, Dict]: the ``(train_data, valid_data)`` dictionaries.
    """
    train_args = cfg.train_args
    use_cache = (
        "train_cache_filename" in train_args and "val_cache_filename" in train_args
    )
    if not use_cache:
        train_data = create_dataset_embedding(cfg, dataset_type="train")
        valid_data = create_dataset_embedding(cfg, dataset_type="valid")
        return train_data, valid_data

    train_data = _load_or_create_cached_embeddings(
        cfg, train_args["train_cache_filename"], "train", "Train", _check_train_cache
    )
    valid_data = _load_or_create_cached_embeddings(
        cfg,
        train_args["val_cache_filename"],
        "valid",
        "Validation",
        _check_valid_cache,
    )
    return train_data, valid_data


def _load_or_create_cached_embeddings(
    cfg, cache_filename, dataset_type, label, check_cached
):
    """Return one embedding dict, read from its cache file or built and then cached."""
    cache_path = str(pathlib.Path(cfg.cache_folder).joinpath(cache_filename))
    if os.path.isfile(cache_path):
        # SECURITY NOTE: unpickling can execute arbitrary code. Only point
        # cfg.cache_folder at files written by this pipeline itself.
        with open(cache_path, "rb") as handle:
            data = pickle.load(handle)
        logging.info("%s data loaded from cache", label)
        check_cached(cfg, data)
        return data

    data = create_dataset_embedding(cfg, dataset_type=dataset_type)
    # makedirs copes with nested cache folders and with the folder appearing
    # between the check and the creation, which a bare os.mkdir does not.
    os.makedirs(cfg.cache_folder, exist_ok=True)
    with open(cache_path, "wb") as handle:
        pickle.dump(data, handle)
    logging.info("%s data saved in cache", label)
    return data


def _require_cached_key(key, container, message):
    """Raise AssertionError(message) unless *key* is present in *container*."""
    # Raised explicitly rather than via `assert` so the check still runs under
    # `python -O`, while keeping the exception type the original code produced.
    if key not in container:
        raise AssertionError(message)


def _check_train_cache(cfg, train_data):
    """Verify the cached train data covers every configured source domain."""
    for source_domain in cfg.train_args["source_domains"]:
        _require_cached_key(
            source_domain,
            train_data,
            "Source domain key does not exist in cached data, consider deleting the cache and rerun the code",
        )


def _check_valid_cache(cfg, valid_data):
    """Verify the cached validation data covers every target language and domain."""
    for target_language in cfg.train_args["target_languages"]:
        _require_cached_key(
            target_language,
            valid_data,
            "Target language key does not exist in cached data, consider deleting the cache and rerun the code",
        )
        for target_domain in cfg.train_args["target_domains"]:
            _require_cached_key(
                target_domain,
                valid_data[target_language],
                "Target domain key does not exist in cached data, consider deleting the cache and rerun the code",
            )
def parse_args_and_load_config(
    config_path: str = "config/ufd_config.json",
) -> UFDArguments:
    """Parse the command line and load the UFD configuration file.

    A single optional ``--config`` argument is read from ``sys.argv``. The JSON
    file it names is loaded and its top-level keys are passed as keyword
    arguments to ``UFDArguments``.

    Args:
        config_path (str): path used when ``--config`` is not given on the
            command line.

    Returns:
        UFDArguments: UFDArguments instance with parsed args
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=config_path)
    args = parser.parse_args()
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(args.config, "r", encoding="utf-8") as cfg_file:
        cfg = json.load(cfg_file)
    return UFDArguments(**cfg)