Source code for sgnlp.models.emotion_entailment.data_class

from dataclasses import dataclass, field
from typing import Dict, Any


@dataclass
class RecconEmotionEntailmentArguments:
    """Configuration for training and evaluating Reccon Emotion Entailment models.

    All values are validated in ``__post_init__``; an ``AssertionError`` is
    raised as soon as an instance is created with an unsupported value.

    Attributes:
        model_name: Pretrained model to use for training. Must be
            ``"roberta-base"`` or ``"roberta-large"``.
        x_train_path: Path of training data.
        x_valid_path: Path of validation data.
        batch_size: Batch size for training. Must be greater than 0.
        max_seq_length: Maximum sequence length. Must be greater than 0.
        train_args: Arguments for training. The keys ``seed``,
            ``num_train_epochs``, ``per_device_train_batch_size``,
            ``per_device_eval_batch_size``, ``learning_rate``,
            ``warmup_ratio`` and ``max_grad_norm`` are required because they
            are validated on construction.
        eval_args: Arguments for evaluation. If
            ``per_device_eval_batch_size`` is present it must be greater
            than 0.
    """

    model_name: str = field(
        default="roberta-base",
        metadata={"help": "Pretrained model to use for training"},
    )
    x_train_path: str = field(
        default="data/dailydialog_classification_train_with_context.csv",
        metadata={"help": "Path of training data"},
    )
    x_valid_path: str = field(
        default="data/dailydialog_classification_valid_with_context.csv",
        metadata={"help": "Path of validation data"},
    )
    batch_size: int = field(
        default=8,
        metadata={"help": "Batch size for training"},
    )
    max_seq_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length"},
    )
    # default_factory gives every instance its own dict, so mutating one
    # instance's train_args/eval_args never leaks into another instance.
    train_args: Dict[str, Any] = field(
        default_factory=lambda: {
            "output_dir": "output/",
            "overwrite_output_dir": True,
            "evaluation_strategy": "steps",
            "per_device_train_batch_size": 8,
            "per_device_eval_batch_size": 8,
            "gradient_accumulation_steps": 1,
            "learning_rate": 1e-5,
            "weight_decay": 0,
            "adam_epsilon": 1e-8,
            "max_grad_norm": 1,
            "num_train_epochs": 12,
            "lr_scheduler_type": "linear",
            "warmup_ratio": 0.06,
            "no_cuda": False,
            "seed": 0,
            "fp16": False,
            "load_best_model_at_end": True,
            "report_to": "none",
        },
        metadata={"help": "Arguments for training Reccon Emotion Entailment models."},
    )
    eval_args: Dict[str, Any] = field(
        default_factory=lambda: {
            "trained_model_dir": "output/",
            "x_test_path": "data/dailydialog_classification_test_with_context.csv",
            "results_path": "output/classification_result.txt",
            "per_device_eval_batch_size": 8,
            "no_cuda": False,
        },
        metadata={"help": "Arguments for evaluating Reccon Emotion Entailment models."},
    )

    def __post_init__(self):
        """Validate the configuration right after the dataclass is built.

        Raises:
            AssertionError: If any validated value is outside its accepted
                range, or if ``train_args`` / ``eval_args`` is not a dict.
            KeyError: If ``train_args`` is missing one of the validated keys.
        """
        # NOTE: validation relies on ``assert`` so callers keep seeing
        # AssertionError; be aware these checks are skipped under ``python -O``.

        # Model
        assert self.model_name in [
            "roberta-base",
            "roberta-large",
        ], "Invalid model type!"

        # Training
        assert self.batch_size > 0, "batch_size must be positive!"
        assert self.max_seq_length > 0, "max_seq_length must be positive!"
        # The type check comes first so the key lookups below cannot fail
        # with a confusing TypeError on a non-dict value.
        assert isinstance(
            self.train_args, dict
        ), "train_args must be represented as a Dictionary."
        # Zero is a valid seed (it is the default), hence ">= 0".
        assert self.train_args["seed"] >= 0, "Random seed must be non-negative!"
        assert (
            self.train_args["num_train_epochs"] > 0
        ), "num_train_epochs must be at least 1."
        assert (
            self.train_args["per_device_train_batch_size"] > 0
        ), "per_device_train_batch_size must be at least 1."
        assert (
            self.train_args["per_device_eval_batch_size"] > 0
        ), "per_device_eval_batch_size must be at least 1."
        assert self.train_args["learning_rate"] > 0, "learning_rate must be positive."
        assert self.train_args["warmup_ratio"] >= 0, "warmup_ratio must be at least 0."
        assert self.train_args["max_grad_norm"] > 0, "max_grad_norm must be positive."

        # Eval
        assert isinstance(
            self.eval_args, dict
        ), "eval_args must be represented as a Dictionary."
        # Validate the *evaluation* dict here (train_args was already checked
        # above). The key is looked up with .get() so that a custom eval_args
        # without it is still accepted, as it was before.
        eval_batch_size = self.eval_args.get("per_device_eval_batch_size")
        assert (
            eval_batch_size is None or eval_batch_size > 0
        ), "per_device_eval_batch_size must be at least 1."