Source code for sgnlp.models.lsr.train

"""
Training or fine-tuning an LSR model on the DocRED dataset.
"""
import argparse
import csv
import json
import logging
import os

import numpy as np
import torch
from sklearn import metrics
from torch.utils.data.dataloader import DataLoader, default_collate
from transformers import set_seed

from .config import LsrConfig
from .modeling import LsrModel, LsrModelOutput
from .preprocess import LsrPreprocessor
from .utils import h_t_idx_generator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LsrMetrics:
    """
    For calculating metrics for LsrModel.
    Computes precision, recall, f1, auc for precision vs recall, and the optimal prediction threshold (theta).

    This is modified from the original. The original additionally computes an ignore score
    which ignores facts seen in the training set.
    """

    def __init__(self, num_relations):
        self.num_relations = num_relations
        self.test_result = []
        self.total_recall = 0

    def add_batch(self, model_output: LsrModelOutput, batch_data):
        predict_re = torch.sigmoid(model_output.prediction)
        predict_re = predict_re.data.cpu().numpy()

        for i, label in enumerate(batch_data["labels"]):
            self.total_recall += len(label)
            # The number of entities in each instance.
            vertex_set_length = batch_data["entity_num_list"][i]
            for j, (h_idx, t_idx) in enumerate(h_t_idx_generator(vertex_set_length)):
                for r in range(1, self.num_relations):
                    result_tuple = (
                        (h_idx, t_idx, r) in label,
                        float(predict_re[i, j, r]),
                    )
                    self.test_result.append(result_tuple)
    def compute(self, input_theta=None):
        """
        Computes metrics based on data added so far.

        Args:
            input_theta (`optional`, `float`):
                Prediction threshold. Provide a value between 0 and 1 if you want to compute the precision
                and recall for that specific threshold. Otherwise the optimal threshold based on f1 score
                will be computed for you.
        """
        # Sorts in descending order by predicted value
        self.test_result.sort(key=lambda x: x[1], reverse=True)

        precision = []
        recall = []
        correct = 0
        w = 0

        if self.total_recall == 0:
            self.total_recall = 1  # for test

        for i, item in enumerate(self.test_result):
            correct += item[0]
            # Precision over the top (i + 1) predictions; recall over all gold facts.
            precision.append(float(correct) / (i + 1))
            recall.append(float(correct) / self.total_recall)
            if input_theta is not None and item[1] > input_theta:
                w = i

        precision = np.asarray(precision, dtype="float32")
        recall = np.asarray(recall, dtype="float32")
        f1_arr = 2 * precision * recall / (precision + recall + 1e-20)
        auc = metrics.auc(x=recall, y=precision)

        if input_theta is None:
            f1 = f1_arr.max()
            f1_pos = f1_arr.argmax()
            best_theta = self.test_result[f1_pos][1]
            return best_theta, f1, precision[f1_pos], recall[f1_pos], auc
        else:
            return input_theta, f1_arr[w], precision[w], recall[w], auc

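# Usage sketch (illustrative, not part of the original module): accumulating predictions
# batch by batch and then computing the optimal threshold. Assumes `model`, `dataloader`,
# and `config` are already constructed as in `train_model` below.
#
#     metrics_tracker = LsrMetrics(num_relations=config.num_relations)
#     for batch in dataloader:
#         with torch.no_grad():
#             outputs = model(**batch)
#         metrics_tracker.add_batch(model_output=outputs, batch_data=batch)
#     best_theta, f1, precision, recall, auc = metrics_tracker.compute()
#     # Or evaluate at a fixed threshold instead of the optimal one:
#     _, f1_fixed, p_fixed, r_fixed, _ = metrics_tracker.compute(input_theta=0.5)
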
class DocredDataset(torch.utils.data.Dataset):
    def __init__(self, json_file, preprocessor):
        with open(json_file) as f:
            self.data = json.load(f)
        self.preprocessed_data = preprocessor(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        instance = {}
        for key in self.preprocessed_data.keys():
            instance[key] = self.preprocessed_data[key][idx]
        return instance

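# Example (hypothetical file paths, shown for illustration only): wrapping a DocRED-style
# json file with the preprocessor from .preprocess, mirroring what `train_model` does below.
#
#     preprocessor = LsrPreprocessor(
#         rel2id_path="metadata/rel2id.json",
#         word2id_path="metadata/word2id.json",
#         ner2id_path="metadata/ner2id.json",
#         device=torch.device("cpu"),
#         config=LsrConfig(),
#     )
#     dataset = DocredDataset(json_file="train_annotated.json", preprocessor=preprocessor)
#     print(len(dataset), dataset[0].keys())
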
def lsr_collate_fn(batch):
    """
    Manually processes labels and uses default for the rest.
    """
    labels = []
    for instance in batch:
        labels_instance = instance.pop("labels")
        labels.append(labels_instance)

    collated_data = default_collate(batch)
    collated_data["labels"] = labels

    return collated_data

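# Sketch of how the collate function is used (assuming a `dataset` like the one sketched
# above): labels stay as variable-length Python lists per document, while the remaining
# tensor fields are stacked by `default_collate`.
#
#     loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lsr_collate_fn)
#     batch = next(iter(loader))
#     assert isinstance(batch["labels"], list)  # one label set per document
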
class MyAdagrad(torch.optim.Optimizer):
    """
    Modification of the Adagrad optimizer that allows specifying an initial accumulator value.
    This mimics the behavior of the default Adagrad implementation in Tensorflow.
    The default PyTorch Adagrad uses 0 for the initial accumulator value.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        lr_decay (float, optional): learning rate decay (default: 0)
        init_accu_value (float, optional): initial accumulator value (default: 0.1)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """

    def __init__(
        self, params, lr=1e-2, lr_decay=0, init_accu_value=0.1, weight_decay=0
    ):
        defaults = dict(
            lr=lr,
            lr_decay=lr_decay,
            init_accu_value=init_accu_value,
            weight_decay=weight_decay,
        )
        super(MyAdagrad, self).__init__(params, defaults)

        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = 0
                state["sum"] = (
                    torch.ones(p.data.size()).type_as(p.data) * init_accu_value
                )

    def share_memory(self):
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["sum"].share_memory_()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    if p.grad.data.is_sparse:
                        raise RuntimeError(
                            "weight_decay option is not compatible with sparse gradients"
                        )
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])

                if p.grad.data.is_sparse:
                    # The update is non-linear so indices must be unique
                    grad = grad.coalesce()
                    grad_indices = grad._indices()
                    grad_values = grad._values()
                    size = torch.Size([x for x in grad.size()])

                    def make_sparse(values):
                        constructor = type(p.grad.data)
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor()
                        return constructor(grad_indices, values, size)

                    state["sum"].add_(make_sparse(grad_values.pow(2)))
                    std = state["sum"].sparse_mask(grad)
                    std_values = std._values().sqrt_().add_(1e-10)
                    p.data.add_(make_sparse(grad_values / std_values), alpha=-clr)
                else:
                    state["sum"].addcmul_(grad, grad, value=1)
                    std = state["sum"].sqrt().add_(1e-10)
                    p.data.addcdiv_(grad, std, value=-clr)

        return loss

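# Minimal sketch of the optimizer on a toy parameter (illustrative only, not from the
# original module):
#
#     w = torch.nn.Parameter(torch.randn(3))
#     opt = MyAdagrad([w], lr=0.1, init_accu_value=0.1)
#     loss = (w ** 2).sum()
#     loss.backward()
#     opt.step()
#     opt.zero_grad()
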
def get_optimizer(name, parameters, lr, l2=0):
    if name == "sgd":
        return torch.optim.SGD(parameters, lr=lr, weight_decay=l2)
    elif name in ["adagrad", "myadagrad"]:
        # use custom adagrad to allow for init accumulator value
        return MyAdagrad(parameters, lr=lr, init_accu_value=0.1, weight_decay=l2)
    elif name == "adam":
        return torch.optim.Adam(parameters, lr, weight_decay=l2)
    elif name == "adamw":
        return torch.optim.AdamW(parameters, lr=lr, weight_decay=l2)
    elif name == "adamax":
        return torch.optim.Adamax(parameters, lr=lr, weight_decay=l2)
    elif name == "adadelta":
        return torch.optim.Adadelta(parameters, lr=lr, weight_decay=l2)
    else:
        raise Exception("Unsupported optimizer: {}".format(name))

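# Example (assumes a constructed `model`): selecting an optimizer by name, as done in
# `train_model` below for the non-BERT case.
#
#     params = [p for p in model.parameters() if p.requires_grad]
#     optimizer = get_optimizer("adam", params, lr=1e-3, l2=0)
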
def parse_args():
    parser = argparse.ArgumentParser(
        description="Train a relation extraction model based on latent structure refinement."
    )
    parser.add_argument(
        "--train_file",
        type=str,
        required=True,
        help="A json file containing the training data.",
    )
    parser.add_argument(
        "--validation_file",
        type=str,
        default=None,
        help="A json file containing the validation data.",
    )
    parser.add_argument(
        "--output_dir", type=str, default=None, help="Output directory path."
    )
    parser.add_argument(
        "--metadata_dir",
        type=str,
        required=True,
        help="Path to docred metadata directory.",
    )

    # Model arguments
    parser.add_argument(
        "--model_weights_path",
        type=str,
        default=None,
        help="Provide a path to model weights if you want to finetune from a checkpoint.",
    )
    parser.add_argument(
        "--model_config_path",
        type=str,
        default=None,
        help="Provide a config if you want to override the defaults. "
        "To use bert encoder layer, specify in config file.",
    )
    parser.add_argument(
        "--pretrained_embeddings_path",
        type=str,
        default=None,
        help="Provide a path to pretrained embeddings if you want to use them.",
    )

    # Training arguments
    parser.add_argument(
        "--lr", type=float, default=1e-3, help="Applies to sgd and adagrad."
    )
    parser.add_argument(
        "--lr_decay", type=float, default=0.98, help="Learning rate decay rate."
    )
    parser.add_argument(
        "--decay_epoch",
        type=int,
        default=20,
        help="Decay learning rate after this epoch.",
    )
    parser.add_argument(
        "--evaluate_epoch", type=int, default=30, help="Evaluate after this epoch."
    )
    parser.add_argument(
        "--optim",
        choices=["sgd", "adagrad", "adam", "adamw", "adamax"],
        default="adam",
        help="Choice of optimizer.",
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0, help="L2 weight decay."
    )
    parser.add_argument(
        "--num_epoch", type=int, default=200, help="Number of total training epochs."
    )
    parser.add_argument(
        "--batch_size", type=int, default=20, help="Training batch size."
    )
    parser.add_argument(
        "--max_grad_norm", type=float, default=5.0, help="Gradient clipping."
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--use_gpu",
        type=bool,
        default=True,
        help="Whether you want to use GPU for training.",
    )
    parser.add_argument(
        "--use_wandb",
        type=bool,
        default=False,
        help="Whether to use wandb to monitor training.",
    )
    parser.add_argument(
        "--wandb_run_name", type=str, default=None, help="Wandb run name."
    )
    args = parser.parse_args()
    return args

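# Example invocation (hypothetical file paths, shown for illustration only):
#
#     python -m sgnlp.models.lsr.train \
#         --train_file train_annotated.json \
#         --validation_file dev.json \
#         --metadata_dir metadata/ \
#         --output_dir output/ \
#         --batch_size 20 \
#         --num_epoch 200
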
def train_model(args):
    logger.info(f"Training arguments: {vars(args)}")

    # Defaults to cpu if gpu is unavailable
    device = (
        torch.device("cuda")
        if args.use_gpu and torch.cuda.is_available()
        else torch.device("cpu")
    )
    logger.info(f"Using device: {device}")

    # Make output dir, save training args, create metrics files
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "best_f1"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "final"), exist_ok=True)

        with open(os.path.join(args.output_dir, "training_args.json"), "w") as fp:
            json.dump(vars(args), fp, sort_keys=True, indent=4)

        # create metric output files
        train_metrics_file = os.path.join(args.output_dir, "train_metrics.csv")
        val_metrics_file = os.path.join(args.output_dir, "val_metrics.csv")
        fieldnames = ["epoch", "loss", "best_theta", "f1", "precision", "recall", "auc"]
        with open(train_metrics_file, "w") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
        with open(val_metrics_file, "w") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()

    if args.seed is not None:
        set_seed(args.seed)

    # Load config if provided else initialize with defaults
    if args.model_config_path:
        config = LsrConfig.from_json_file(json_file=args.model_config_path)
    else:
        config = LsrConfig()

    # Load model weights if provided
    if args.model_weights_path:
        model = LsrModel.from_pretrained(args.model_weights_path, config=config)
    else:
        model = LsrModel(config)

    if args.use_wandb:
        import wandb

        wandb.init(project="lsr", name=args.wandb_run_name)
        wandb.watch(model, log="all")

    # Note: this will override the provided model weights
    if args.pretrained_embeddings_path is not None and not config.use_bert:
        pretrained_embeddings = np.load(args.pretrained_embeddings_path)
        model.load_pretrained_word_embedding(
            pretrained_word_embedding=pretrained_embeddings
        )

    # Set to training device
    model.to(device=device)

    # Load dataset
    # Set to cpu initially (for preprocessing entire dataset first)
    logger.info("Preprocessing datasets...")
    rel2id_path = os.path.join(args.metadata_dir, "rel2id.json")
    word2id_path = os.path.join(args.metadata_dir, "word2id.json")
    ner2id_path = os.path.join(args.metadata_dir, "ner2id.json")
    train_preprocessor = LsrPreprocessor(
        rel2id_path=rel2id_path,
        word2id_path=word2id_path,
        ner2id_path=ner2id_path,
        is_train=True,
        device=torch.device("cpu"),
        config=config,
    )
    val_preprocessor = LsrPreprocessor(
        rel2id_path=rel2id_path,
        word2id_path=word2id_path,
        ner2id_path=ner2id_path,
        device=torch.device("cpu"),
        config=config,
    )
    train_dataset = DocredDataset(
        json_file=args.train_file, preprocessor=train_preprocessor
    )
    val_dataset = DocredDataset(
        json_file=args.validation_file, preprocessor=val_preprocessor
    )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=args.batch_size,
        collate_fn=lsr_collate_fn,
    )
    val_dataloader = DataLoader(
        val_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=lsr_collate_fn
    )

    # Optimizer and parameters
    if config.use_bert:
        other_params = [
            p
            for name, p in model.named_parameters()
            if p.requires_grad and not name.startswith("bert")
        ]
        optimizer = torch.optim.Adam(
            [
                {"params": other_params, "lr": args.lr},
                {"params": model.bert.parameters(), "lr": 1e-5},
            ],
            lr=args.lr,
        )
    else:
        parameters = [p for p in model.parameters() if p.requires_grad]
        optimizer = get_optimizer(args.optim, parameters, args.lr)

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, args.lr_decay)

    best_f1 = 0
    for epoch in range(args.num_epoch):
        logger.info(
            f"Epoch: {epoch + 1}/{args.num_epoch}, lr: {optimizer.param_groups[0]['lr']}"
        )
        train_metrics = LsrMetrics(num_relations=config.num_relations)
        total_epoch_train_loss = 0
        model.train()
        for step, batch in enumerate(train_dataloader):
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.to(device=device)
            outputs = model(**batch)

            # Backpropagation
            loss = outputs.loss

            # Skip the update if the loss is not finite
            if np.isnan(loss.item()):
                logger.info("Skipping backward prop.")
                continue

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            # Track metrics
            total_epoch_train_loss += loss.item()
            train_metrics.add_batch(model_output=outputs, batch_data=batch)

        # Compute metrics and log at end of epoch
        best_theta, f1, precision, recall, auc = train_metrics.compute()
        avg_epoch_train_loss = total_epoch_train_loss / (step + 1)
        logger.info(
            f"Train loss: {avg_epoch_train_loss:.3f}, best theta: {best_theta:.3f}, "
            f"f1: {f1:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, auc: {auc:.5f}"
        )

        if args.use_wandb:
            wandb.log(
                {
                    "train_loss": avg_epoch_train_loss,
                    "train_best_theta": best_theta,
                    "train_f1": f1,
                    "train_precision": precision,
                    "train_recall": recall,
                    "train_auc": auc,
                },
                step=epoch,
            )

        # Write train metrics
        if args.output_dir is not None:
            with open(train_metrics_file, "a") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writerow(
                    {
                        "epoch": epoch + 1,
                        "loss": avg_epoch_train_loss,
                        "best_theta": best_theta,
                        "f1": f1,
                        "precision": precision,
                        "recall": recall,
                        "auc": auc,
                    }
                )

        if epoch + 1 >= args.evaluate_epoch:
            val_metrics = LsrMetrics(num_relations=config.num_relations)
            total_epoch_val_loss = 0
            model.eval()
            for step, batch in enumerate(val_dataloader):
                for key, value in batch.items():
                    if isinstance(value, torch.Tensor):
                        batch[key] = value.to(device=device)
                with torch.no_grad():
                    outputs = model(**batch)

                # Track metrics
                total_epoch_val_loss += outputs.loss.item()
                val_metrics.add_batch(model_output=outputs, batch_data=batch)

            # Compute metrics and log at end of epoch
            best_theta, f1, precision, recall, auc = val_metrics.compute()
            avg_epoch_val_loss = total_epoch_val_loss / (step + 1)
            logger.info(
                f"Val loss: {avg_epoch_val_loss:.3f}, best theta: {best_theta:.3f}, "
                f"f1: {f1:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, auc: {auc:.5f}"
            )

            if args.use_wandb:
                wandb.log(
                    {
                        "val_loss": avg_epoch_val_loss,
                        "val_best_theta": best_theta,
                        "val_f1": f1,
                        "val_precision": precision,
                        "val_recall": recall,
                        "val_auc": auc,
                    },
                    step=epoch,
                )

            # Write val metrics
            if args.output_dir is not None:
                with open(val_metrics_file, "a") as f:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writerow(
                        {
                            "epoch": epoch + 1,
                            "loss": avg_epoch_val_loss,
                            "best_theta": best_theta,
                            "f1": f1,
                            "precision": precision,
                            "recall": recall,
                            "auc": auc,
                        }
                    )

            # Save best model so far
            if args.output_dir is not None and f1 > best_f1:
                logger.info("Best f1 so far, saving model.")
                best_f1 = f1
                model.save_pretrained(os.path.join(args.output_dir, "best_f1"))

        if epoch + 1 >= args.decay_epoch:
            if args.optim == "adam" and optimizer.param_groups[0]["lr"] > 1e-4:
                scheduler.step()

    # Save final model
    if args.output_dir is not None:
        logger.info("Saving final model.")
        model.save_pretrained(os.path.join(args.output_dir, "final"))

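# Programmatic alternative to the CLI entry point below (illustrative; field values are
# placeholders): `train_model` only reads attributes from the args object, so an
# `argparse.Namespace` built in code works as well.
#
#     args = argparse.Namespace(
#         train_file="train_annotated.json",
#         validation_file="dev.json",
#         metadata_dir="metadata/",
#         output_dir=None,
#         model_weights_path=None,
#         model_config_path=None,
#         pretrained_embeddings_path=None,
#         lr=1e-3, lr_decay=0.98, decay_epoch=20, evaluate_epoch=30,
#         optim="adam", weight_decay=0, num_epoch=200, batch_size=20,
#         max_grad_norm=5.0, seed=0, use_gpu=True, use_wandb=False,
#         wandb_run_name=None,
#     )
#     train_model(args)
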
if __name__ == "__main__":
    args = parse_args()
    train_model(args)