Source code for sgnlp.models.span_extraction.train

import json
import math

from transformers import Trainer
from transformers import TrainingArguments

from .config import RecconSpanExtractionConfig
from .data_class import RecconSpanExtractionArguments
from .modeling import RecconSpanExtractionModel
from .tokenization import RecconSpanExtractionTokenizer
from .utils import parse_args_and_load_config, load_examples, RecconSpanExtractionData


[docs]def train_model(cfg: RecconSpanExtractionArguments): """ Method for training RecconSpanExtractionModel. Args: config (:obj:`RecconSpanExtractionArguments`): RecconSpanExtractionArguments config load from config file. Example:: import json from sgnlp.models.span_extraction import train from sgnlp.models.span_extraction.utils import parse_args_and_load_config cfg = parse_args_and_load_config('config/span_extraction_config.json') train(cfg) """ config = RecconSpanExtractionConfig.from_pretrained(cfg.model_name) tokenizer = RecconSpanExtractionTokenizer.from_pretrained(cfg.model_name) model = RecconSpanExtractionModel.from_pretrained(cfg.model_name, config=config) with open(cfg.train_data_path, "r") as train_file: train_json = json.load(train_file) with open(cfg.val_data_path, "r") as val_file: val_json = json.load(val_file) load_train_exp_args = { "examples": train_json, "tokenizer": tokenizer, "max_seq_length": cfg.max_seq_length, "doc_stride": cfg.doc_stride, "max_query_length": cfg.max_query_length, } load_valid_exp_args = { "examples": val_json, "tokenizer": tokenizer, "max_seq_length": cfg.max_seq_length, "doc_stride": cfg.doc_stride, "max_query_length": cfg.max_query_length, } train_dataset = load_examples(**load_train_exp_args) val_dataset = load_examples(**load_valid_exp_args) t_total = ( len(train_dataset) // cfg.train_args["gradient_accumulation_steps"] * cfg.train_args["num_train_epochs"] ) cfg.train_args["eval_steps"] = int( len(train_dataset) / cfg.train_args["per_device_train_batch_size"] ) cfg.train_args["warmup_steps"] = math.ceil(t_total * cfg.train_args["warmup_ratio"]) training_args = TrainingArguments(**cfg.train_args) trainer = Trainer( model=model, args=training_args, train_dataset=RecconSpanExtractionData(train_dataset), eval_dataset=RecconSpanExtractionData(val_dataset), ) trainer.train() trainer.save_model()
if __name__ == "__main__": cfg = parse_args_and_load_config() train_model(cfg)