Source code for sgnlp.models.span_extraction.data_class

from dataclasses import dataclass, field
from typing import Dict, Any


@dataclass
class RecconSpanExtractionArguments:
    """Configuration container for Reccon Span Extraction training and evaluation.

    All values are validated in ``__post_init__``; an invalid value raises
    ``AssertionError`` carrying a message that names the offending setting.

    Attributes:
        model_name: Pretrained model identifier. Only the two names accepted
            in ``__post_init__`` are allowed.
        train_data_path: Path of the training data file.
        val_data_path: Path of the validation data file.
        test_data_path: Path of the test data file.
        max_seq_length: Maximum sequence length; must be positive.
        doc_stride: Document stride; must be positive.
        max_query_length: Maximum query length; must be positive.
        train_args: Dictionary of training settings.
            NOTE(review): the keys look like Hugging Face ``TrainingArguments``
            names -- confirm against the code that consumes this dictionary.
        eval_args: Dictionary of evaluation settings.
    """

    model_name: str = field(
        default="mrm8488/spanbert-finetuned-squadv2",
        metadata={"help": "Pretrained model to use for training"},
    )
    train_data_path: str = field(
        default="data/subtask1/fold1/dailydialog_qa_train_with_context.json",
        metadata={"help": "Path of training data"},
    )
    val_data_path: str = field(
        default="data/subtask1/fold1/dailydialog_qa_valid_with_context.json",
        metadata={"help": "Path of validation data"},
    )
    test_data_path: str = field(
        default="data/subtask1/fold1/dailydialog_qa_test_with_context.json",
        # Previously duplicated the validation help text.
        metadata={"help": "Path of test data"},
    )
    max_seq_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length"},
    )
    doc_stride: int = field(
        default=512,
        metadata={"help": "Document stride"},
    )
    max_query_length: int = field(
        default=512,
        metadata={"help": "Maximum query length"},
    )
    # default_factory gives every instance its own dictionary, so mutating one
    # instance's settings never leaks into another.
    train_args: Dict[str, Any] = field(
        default_factory=lambda: {
            "output_dir": "output/",
            "overwrite_output_dir": True,
            "evaluation_strategy": "steps",
            "per_device_train_batch_size": 16,
            "per_device_eval_batch_size": 16,
            "gradient_accumulation_steps": 1,
            "learning_rate": 1e-5,
            "weight_decay": 0,
            "adam_epsilon": 1e-8,
            "max_grad_norm": 1,
            "num_train_epochs": 12,
            "warmup_ratio": 0.06,
            "no_cuda": False,
            "seed": 0,
            "fp16": False,
            "load_best_model_at_end": True,
            "label_names": ["start_positions", "end_positions"],
            "report_to": "none",
        },
        metadata={"help": "Arguments for training Reccon Span Extraction models."},
    )
    eval_args: Dict[str, Any] = field(
        default_factory=lambda: {
            "trained_model_dir": "output/",
            "results_path": "result/",
            "batch_size": 16,
            "n_best_size": 20,
            "null_score_diff_threshold": 0.0,
            "sliding_window": False,
            "no_cuda": False,
            "max_answer_length": 200,
        },
        metadata={"help": "Arguments for evaluating Reccon Span Extraction models."},
    )

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise ``AssertionError(message)`` unless ``condition`` is true.

        Raising explicitly (instead of using the ``assert`` statement) keeps
        the validation active under ``python -O`` while preserving the
        exception type and messages that existing callers may rely on.
        """
        if not condition:
            raise AssertionError(message)

    def __post_init__(self):
        """Validate the configuration right after dataclass initialisation.

        Checks run in a fixed order and stop at the first failure.

        Raises:
            AssertionError: If any setting is outside its allowed range or
                ``train_args`` / ``eval_args`` is not a dictionary.
            KeyError: If ``train_args`` or ``eval_args`` lacks a key that is
                validated here.
        """
        require = self._require

        # Model
        require(
            self.model_name
            in ("mrm8488/spanbert-finetuned-squadv2", "roberta-base"),
            "Invalid model type!",
        )

        # Training
        require(self.max_seq_length > 0, "max_seq_length must be positive.")
        require(self.doc_stride > 0, "doc_stride must be positive.")
        require(self.max_query_length > 0, "max_query_length must be positive.")
        require(
            isinstance(self.train_args, dict),
            "train_args must be represented as a Dictionary.",
        )
        train = self.train_args
        require(train["seed"] >= 0, "Random seed must be at least 0.")
        require(
            train["num_train_epochs"] > 0,
            "num_train_epochs must be at least 1.",
        )
        require(
            train["per_device_train_batch_size"] > 0,
            "per_device_train_batch_size must be at least 1.",
        )
        require(
            train["per_device_eval_batch_size"] > 0,
            "per_device_eval_batch_size must be at least 1.",
        )
        require(
            train["gradient_accumulation_steps"] > 0,
            "gradient_accumulation_steps must be positive.",
        )
        require(train["learning_rate"] > 0, "learning_rate must be positive.")
        require(train["warmup_ratio"] >= 0, "warmup_ratio must be at least 0.")
        require(train["weight_decay"] >= 0, "weight_decay must be at least 0.")
        require(train["max_grad_norm"] > 0, "max_grad_norm must be positive.")
        require(train["adam_epsilon"] >= 0, "adam_epsilon must be at least 0.")

        # Eval
        require(
            isinstance(self.eval_args, dict),
            "eval_args must be represented as a Dictionary.",
        )
        evaluation = self.eval_args
        require(evaluation["n_best_size"] >= 1, "n_best_size must be at least 1.")
        require(
            evaluation["null_score_diff_threshold"] >= 0,
            "null_score_diff_threshold must be at least 0.",
        )
        require(
            evaluation["max_answer_length"] >= 1,
            "max_answer_length must be at least 1.",
        )