import copy
import logging
import os
import pickle
import random
from typing import List
import numpy as np
import torch
from sgnlp.utils.csv_writer import CsvWriter
from .data_class import RstPointerParserTrainArgs, RstPointerSegmenterTrainArgs
from .modeling import (
RstPointerParserModel,
RstPointerParserConfig,
RstPointerSegmenterModel,
RstPointerSegmenterConfig,
)
from .modules.type import DiscourseTreeNode, DiscourseTreeSplit
from .preprocess import RstPreprocessor
from .utils import parse_args_and_load_config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Shared functions
def setup(seed):
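    """Seed all relevant RNGs (torch, CUDA, numpy, random) and configure
    cuDNN for deterministic behaviour, so training runs are reproducible."""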
# Set seeds
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
# See: https://pytorch.org/docs/stable/notes/randomness.html
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def adjust_learning_rate(optimizer, epoch, lr_decay=0.5, lr_decay_epoch=50):
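    """Multiply each param group's learning rate by ``lr_decay`` once every
    ``lr_decay_epoch`` epochs (epoch 0 is skipped).

    For example, with lr=0.01, lr_decay=0.5 and lr_decay_epoch=50, the
    learning rate becomes 0.005 at epoch 50 and 0.0025 at epoch 100.
    """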
if (epoch % lr_decay_epoch == 0) and (epoch != 0):
for param_group in optimizer.param_groups:
param_group["lr"] = param_group["lr"] * lr_decay
# Parser training code
def get_span_dict(discourse_tree_splits: List[DiscourseTreeSplit]):
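    """Flatten a list of discourse tree splits into a dict mapping each
    span key ("start-end") to its [label, ns_type] pair.

    Illustrative example (hypothetical nodes): two splits covering EDUs
    0-3/4-7 and 0-1/2-3 would produce keys "0-3", "4-7", "0-1" and "2-3".
    """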
span_dict = {}
for split in discourse_tree_splits:
left_span_key, left_node_value = get_span_key_and_node_value(split.left)
span_dict[left_span_key] = left_node_value
right_span_key, right_node_value = get_span_key_and_node_value(split.right)
span_dict[right_span_key] = right_node_value
return span_dict
def get_span_key_and_node_value(node: DiscourseTreeNode):
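    """Return the node's span as a "start-end" string key together with its
    [relation label, nuclearity type] value."""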
span_key = f"{node.span[0]}-{node.span[1]}"
node_value = [node.label, node.ns_type]
return span_key, node_value
def get_measurement(discourse_tree_splits_1, discourse_tree_splits_2):
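    """Compare two lists of discourse tree splits: count the spans present
    in both trees, and how many of those also agree on relation label and
    on nuclearity. Also returns each tree's total span count, which
    get_micro_measure() uses as the system/golden denominators."""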
span_dict1 = get_span_dict(discourse_tree_splits_1)
span_dict2 = get_span_dict(discourse_tree_splits_2)
num_correct_relations = 0
num_correct_nuclearity = 0
num_spans_1 = len(span_dict1)
num_spans_2 = len(span_dict2)
    # number of spans present in both trees
matching_spans = list(set(span_dict1.keys()).intersection(set(span_dict2.keys())))
num_matching_spans = len(matching_spans)
# count matching relations and nuclearity
for span in matching_spans:
if span_dict1[span][0] == span_dict2[span][0]:
num_correct_relations += 1
if span_dict1[span][1] == span_dict2[span][1]:
num_correct_nuclearity += 1
return (
num_matching_spans,
num_correct_relations,
num_correct_nuclearity,
num_spans_1,
num_spans_2,
)
def get_batch_measure(input_splits_batch, golden_metric_batch):
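    """Accumulate get_measurement() counts over a batch of predicted and
    golden splits. When only one side has splits, its spans still count
    towards the corresponding denominator (each split contributes two
    spans, one per child)."""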
num_matching_spans = 0
num_correct_relations = 0
num_correct_nuclearity = 0
num_spans_input = 0
num_spans_golden = 0
for input_splits, golden_splits in zip(input_splits_batch, golden_metric_batch):
if input_splits and golden_splits:
            # both trees have at least one split
(
_num_matching_spans,
_num_correct_relations,
_num_correct_nuclearity,
_num_spans_input,
_num_spans_golden,
) = get_measurement(input_splits, golden_splits)
num_matching_spans += _num_matching_spans
num_correct_relations += _num_correct_relations
num_correct_nuclearity += _num_correct_nuclearity
num_spans_input += _num_spans_input
num_spans_golden += _num_spans_golden
elif input_splits and not golden_splits:
# each split has 2 spans
num_spans_input += len(input_splits) * 2
elif not input_splits and golden_splits:
num_spans_golden += len(golden_splits) * 2
return (
num_matching_spans,
num_correct_relations,
num_correct_nuclearity,
num_spans_input,
num_spans_golden,
)
def get_micro_measure(
correct_span, correct_relation, correct_nuclearity, no_system, no_golden
):
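    """Compute micro-averaged (precision, recall, F1) triples for span,
    relation and nuclearity from the accumulated batch counts.

    Example::

        >>> get_micro_measure(8, 6, 7, 10, 10)
        ((0.8, 0.8, 0.8), (0.6, 0.6, 0.6), (0.7, 0.7, 0.7))
    """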
# Compute Micro-average measure
# Span
precision_span = correct_span / no_system
recall_span = correct_span / no_golden
f1_span = (2 * correct_span) / (no_golden + no_system)
# Relation
precision_relation = correct_relation / no_system
recall_relation = correct_relation / no_golden
f1_relation = (2 * correct_relation) / (no_golden + no_system)
# Nuclearity
precision_nuclearity = correct_nuclearity / no_system
recall_nuclearity = correct_nuclearity / no_golden
f1_nuclearity = (2 * correct_nuclearity) / (no_golden + no_system)
return (
(precision_span, recall_span, f1_span),
(precision_relation, recall_relation, f1_relation),
(precision_nuclearity, recall_nuclearity, f1_nuclearity),
)
# TODO: Change data sampling
def get_batch_data_training(
input_sentences,
edu_breaks,
decoder_input,
relation_label,
parsing_breaks,
golden_metric,
parents_index,
sibling,
batch_size,
):
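    """Randomly sample a training batch (capped at the dataset size) and
    sort it by sentence length, longest first. Returns the batched
    sentences, EDU breaks, decoder inputs, relation labels, parsing
    breaks, golden metrics, parent indices and sibling indices as lists.
    """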
    # Convert to object arrays so the ragged lists can be indexed together
input_sentences = np.array(input_sentences, dtype="object")
edu_breaks = np.array(edu_breaks, dtype="object")
decoder_input = np.array(decoder_input, dtype="object")
relation_label = np.array(relation_label, dtype="object")
parsing_breaks = np.array(parsing_breaks, dtype="object")
golden_metric = np.array(golden_metric, dtype="object")
parents_index = np.array(parents_index, dtype="object")
sibling = np.array(sibling, dtype="object")
if len(decoder_input) < batch_size:
batch_size = len(decoder_input)
sample_indices = random.sample(range(len(decoder_input)), batch_size)
# Get batch data
input_sentences_batch = copy.deepcopy(input_sentences[sample_indices])
edu_breaks_batch = copy.deepcopy(edu_breaks[sample_indices])
decoder_input_batch = copy.deepcopy(decoder_input[sample_indices])
relation_label_batch = copy.deepcopy(relation_label[sample_indices])
parsing_breaks_batch = copy.deepcopy(parsing_breaks[sample_indices])
golden_metric_batch = copy.deepcopy(golden_metric[sample_indices])
parents_index_batch = copy.deepcopy(parents_index[sample_indices])
sibling_batch = copy.deepcopy(sibling[sample_indices])
    # Sort by sentence length, longest first
lengths_batch = np.array([len(sent) for sent in input_sentences_batch])
idx = np.argsort(lengths_batch)
idx = idx[::-1]
    # Convert back to lists
input_sentences_batch = input_sentences_batch[idx].tolist()
edu_breaks_batch = edu_breaks_batch[idx].tolist()
decoder_input_batch = decoder_input_batch[idx].tolist()
relation_label_batch = relation_label_batch[idx].tolist()
parsing_breaks_batch = parsing_breaks_batch[idx].tolist()
golden_metric_batch = golden_metric_batch[idx].tolist()
parents_index_batch = parents_index_batch[idx].tolist()
sibling_batch = sibling_batch[idx].tolist()
return (
input_sentences_batch,
edu_breaks_batch,
decoder_input_batch,
relation_label_batch,
parsing_breaks_batch,
golden_metric_batch,
parents_index_batch,
sibling_batch,
)
def get_batch_data(
input_sentences,
edu_breaks,
decoder_input,
relation_label,
parsing_breaks,
golden_metric,
batch_size,
):
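    """Same sampling and length-sorting as get_batch_data_training(), but
    without the parent and sibling indices; used during evaluation."""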
    # Convert to object arrays so the ragged lists can be indexed together
input_sentences = np.array(input_sentences, dtype="object")
edu_breaks = np.array(edu_breaks, dtype="object")
decoder_input = np.array(decoder_input, dtype="object")
relation_label = np.array(relation_label, dtype="object")
parsing_breaks = np.array(parsing_breaks, dtype="object")
golden_metric = np.array(golden_metric, dtype="object")
if len(decoder_input) < batch_size:
batch_size = len(decoder_input)
sample_indices = random.sample(range(len(decoder_input)), batch_size)
# Get batch data
input_sentences_batch = copy.deepcopy(input_sentences[sample_indices])
edu_breaks_batch = copy.deepcopy(edu_breaks[sample_indices])
decoder_input_batch = copy.deepcopy(decoder_input[sample_indices])
relation_label_batch = copy.deepcopy(relation_label[sample_indices])
parsing_breaks_batch = copy.deepcopy(parsing_breaks[sample_indices])
golden_metric_batch = copy.deepcopy(golden_metric[sample_indices])
    # Sort by sentence length, longest first
lengths_batch = np.array([len(sent) for sent in input_sentences_batch])
idx = np.argsort(lengths_batch)
idx = idx[::-1]
    # Convert back to lists
input_sentences_batch = input_sentences_batch[idx].tolist()
edu_breaks_batch = edu_breaks_batch[idx].tolist()
decoder_input_batch = decoder_input_batch[idx].tolist()
relation_label_batch = relation_label_batch[idx].tolist()
parsing_breaks_batch = parsing_breaks_batch[idx].tolist()
golden_metric_batch = golden_metric_batch[idx].tolist()
return (
input_sentences_batch,
edu_breaks_batch,
decoder_input_batch,
relation_label_batch,
parsing_breaks_batch,
golden_metric_batch,
)
def get_accuracy(
model,
preprocessor,
input_sentences,
edu_breaks,
decoder_input,
relation_label,
parsing_breaks,
golden_metric,
batch_size,
):
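    """Run the parser over the full evaluation set in mini-batches and
    return the mean tree/label losses plus micro-averaged span, relation
    and nuclearity (precision, recall, F1) triples."""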
num_loops = int(np.ceil(len(edu_breaks) / batch_size))
loss_tree_all = []
loss_label_all = []
correct_span = 0
correct_relation = 0
correct_nuclearity = 0
no_system = 0
no_golden = 0
for loop in range(num_loops):
start_idx = loop * batch_size
end_idx = (loop + 1) * batch_size
if end_idx > len(edu_breaks):
end_idx = len(edu_breaks)
(
input_sentences_batch,
edu_breaks_batch,
_,
relation_label_batch,
parsing_breaks_batch,
golden_metric_splits_batch,
) = get_batch_data(
input_sentences[start_idx:end_idx],
edu_breaks[start_idx:end_idx],
decoder_input[start_idx:end_idx],
relation_label[start_idx:end_idx],
parsing_breaks[start_idx:end_idx],
golden_metric[start_idx:end_idx],
batch_size,
)
input_sentences_ids_batch, sentence_lengths = preprocessor.get_elmo_char_ids(
input_sentences_batch
)
input_sentences_ids_batch = input_sentences_ids_batch.to(device=model.device)
model_output = model(
input_sentence_ids=input_sentences_ids_batch,
edu_breaks=edu_breaks_batch,
sentence_lengths=sentence_lengths,
label_index=relation_label_batch,
parsing_index=parsing_breaks_batch,
generate_splits=True,
)
loss_tree_all.append(model_output.loss_tree)
loss_label_all.append(model_output.loss_label)
(
correct_span_batch,
correct_relation_batch,
correct_nuclearity_batch,
no_system_batch,
no_golden_batch,
) = get_batch_measure(model_output.splits, golden_metric_splits_batch)
        correct_span += correct_span_batch
        correct_relation += correct_relation_batch
        correct_nuclearity += correct_nuclearity_batch
        no_system += no_system_batch
        no_golden += no_golden_batch
span_points, relation_points, nuclearity_points = get_micro_measure(
correct_span, correct_relation, correct_nuclearity, no_system, no_golden
)
return (
np.mean(loss_tree_all),
np.mean(loss_label_all),
span_points,
relation_points,
nuclearity_points,
)
def train_parser(cfg: RstPointerParserTrainArgs) -> None:
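    """Train the RST pointer parser end-to-end: load the preprocessed
    pickles, train with Adam, evaluate on the test set each epoch, and
    keep the checkpoint with the best relation F1.

    Illustrative usage (assuming a config file matching
    RstPointerParserTrainArgs, as in the ``__main__`` block below)::

        cfg = parse_args_and_load_config()
        train_parser(cfg)
    """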
logger.info(f"===== Training RST Pointer Parser =====")
# Setup
setup(seed=cfg.seed)
train_data_dir = cfg.train_data_dir
test_data_dir = cfg.test_data_dir
save_dir = cfg.save_dir
batch_size = cfg.batch_size
hidden_size = cfg.hidden_size
rnn_layers = cfg.num_rnn_layers
dropout_e = cfg.dropout_e
dropout_d = cfg.dropout_d
dropout_c = cfg.dropout_c
atten_model = cfg.atten_model
classifier_input_size = cfg.classifier_input_size
classifier_hidden_size = cfg.classifier_hidden_size
classifier_bias = cfg.classifier_bias
elmo_size = cfg.elmo_size
epochs = cfg.epochs
lr = cfg.lr
lr_decay_epoch = cfg.lr_decay_epoch
weight_decay = cfg.weight_decay
highorder = cfg.highorder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Device: {device}")
# Create directory and files
os.makedirs(save_dir, exist_ok=True)
best_results_writer = CsvWriter(
file_path=os.path.join(save_dir, "best_results.csv"),
fieldnames=[
"best_epoch",
"f1_relation",
"precision_relation",
"recall_relation",
"f1_span",
"precision_span",
"recall_span",
"f1_nuclearity",
"precision_nuclearity",
"recall_nuclearity",
],
)
results_writer = CsvWriter(
file_path=os.path.join(save_dir, "results.csv"),
fieldnames=[
"current_epoch",
"loss_tree_test",
"loss_label_test",
"f1_span",
"f1_relation",
"f1_nuclearity",
],
)
logger.info("Loading training and test data...")
    # Load training data
    with open(os.path.join(train_data_dir, "tokenized_sentences.pickle"), "rb") as f:
        tr_input_sentences = pickle.load(f)
    with open(os.path.join(train_data_dir, "edu_breaks.pickle"), "rb") as f:
        tr_edu_breaks = pickle.load(f)
    with open(os.path.join(train_data_dir, "decoder_input_index.pickle"), "rb") as f:
        tr_decoder_input = pickle.load(f)
    with open(os.path.join(train_data_dir, "relation_index.pickle"), "rb") as f:
        tr_relation_label = pickle.load(f)
    with open(os.path.join(train_data_dir, "splits_order.pickle"), "rb") as f:
        tr_parsing_breaks = pickle.load(f)
    with open(os.path.join(train_data_dir, "discourse_tree_splits.pickle"), "rb") as f:
        tr_golden_metric = pickle.load(f)
    with open(os.path.join(train_data_dir, "parent_index.pickle"), "rb") as f:
        tr_parents_index = pickle.load(f)
    with open(os.path.join(train_data_dir, "sibling_index.pickle"), "rb") as f:
        tr_sibling_index = pickle.load(f)
    # Load testing data
    with open(os.path.join(test_data_dir, "tokenized_sentences.pickle"), "rb") as f:
        test_input_sentences = pickle.load(f)
    with open(os.path.join(test_data_dir, "edu_breaks.pickle"), "rb") as f:
        test_edu_breaks = pickle.load(f)
    with open(os.path.join(test_data_dir, "decoder_input_index.pickle"), "rb") as f:
        test_decoder_input = pickle.load(f)
    with open(os.path.join(test_data_dir, "relation_index.pickle"), "rb") as f:
        test_relation_label = pickle.load(f)
    with open(os.path.join(test_data_dir, "splits_order.pickle"), "rb") as f:
        test_parsing_breaks = pickle.load(f)
    with open(os.path.join(test_data_dir, "discourse_tree_splits.pickle"), "rb") as f:
        test_golden_metric = pickle.load(f)
logger.info("--------------------------------------------------------------------")
logger.info("Starting model training...")
logger.info("--------------------------------------------------------------------")
# Initialize model
model_config = RstPointerParserConfig(
hidden_size=hidden_size,
decoder_input_size=hidden_size,
atten_model=atten_model,
classifier_input_size=classifier_input_size,
classifier_hidden_size=classifier_hidden_size,
highorder=highorder,
classes_label=39,
classifier_bias=classifier_bias,
rnn_layers=rnn_layers,
dropout_e=dropout_e,
dropout_d=dropout_d,
dropout_c=dropout_c,
elmo_size=elmo_size,
)
model = RstPointerParserModel(model_config)
model = model.to(device)
preprocessor = RstPreprocessor()
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=lr,
betas=(0.9, 0.9),
weight_decay=weight_decay,
)
num_iterations = int(np.ceil(len(tr_parsing_breaks) / batch_size))
    # Initialise best-metric trackers so they are defined even if no epoch
    # improves on the starting F1
    best_epoch = 0
    best_f1_relation = 0
    best_precision_relation = 0
    best_recall_relation = 0
    best_f1_span = 0
    best_precision_span = 0
    best_recall_span = 0
    best_f1_nuclearity = 0
    best_precision_nuclearity = 0
    best_recall_nuclearity = 0
for current_epoch in range(epochs):
adjust_learning_rate(optimizer, current_epoch, 0.8, lr_decay_epoch)
for current_iteration in range(num_iterations):
(
input_sentences_batch,
edu_breaks_batch,
decoder_input_batch,
relation_label_batch,
parsing_breaks_batch,
_,
parents_index_batch,
sibling_batch,
) = get_batch_data_training(
tr_input_sentences,
tr_edu_breaks,
tr_decoder_input,
tr_relation_label,
tr_parsing_breaks,
tr_golden_metric,
tr_parents_index,
tr_sibling_index,
batch_size,
)
model.zero_grad()
(
input_sentences_ids_batch,
sentence_lengths,
) = preprocessor.get_elmo_char_ids(input_sentences_batch)
input_sentences_ids_batch = input_sentences_ids_batch.to(
device=model.device
)
loss_tree_batch, loss_label_batch = model.forward_train(
input_sentence_ids_batch=input_sentences_ids_batch,
edu_breaks_batch=edu_breaks_batch,
label_index_batch=relation_label_batch,
parsing_index_batch=parsing_breaks_batch,
decoder_input_index_batch=decoder_input_batch,
parents_index_batch=parents_index_batch,
sibling_index_batch=sibling_batch,
sentence_lengths=sentence_lengths,
)
loss = loss_tree_batch + loss_label_batch
loss.backward()
cur_loss = float(loss.item())
logger.info(
f"Epoch: {current_epoch + 1}/{epochs}, "
f"iteration: {current_iteration + 1}/{num_iterations}, "
f"loss: {cur_loss:.3f}"
)
# To avoid gradient explosion
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()
# Convert model to eval
model.eval()
# Eval on Testing data
(
loss_tree_test,
loss_label_test,
span_points_test,
relation_points_test,
nuclearity_points_test,
) = get_accuracy(
model,
preprocessor,
test_input_sentences,
test_edu_breaks,
test_decoder_input,
test_relation_label,
test_parsing_breaks,
test_golden_metric,
batch_size,
)
        # Unpack test metrics
precision_span, recall_span, f1_span = span_points_test
precision_relation, recall_relation, f1_relation = relation_points_test
precision_nuclearity, recall_nuclearity, f1_nuclearity = nuclearity_points_test
        # Relation F1 takes priority when selecting the best checkpoint
if f1_relation > best_f1_relation:
best_epoch = current_epoch
# relation
best_f1_relation = f1_relation
best_precision_relation = precision_relation
best_recall_relation = recall_relation
# span
best_f1_span = f1_span
best_precision_span = precision_span
best_recall_span = recall_span
# nuclearity
best_f1_nuclearity = f1_nuclearity
best_precision_nuclearity = precision_nuclearity
best_recall_nuclearity = recall_nuclearity
# Log evaluation and test metrics
epoch_metrics = {
"current_epoch": current_epoch,
"loss_tree_test": loss_tree_test,
"loss_label_test": loss_label_test,
"f1_span": f1_span,
"f1_relation": f1_relation,
"f1_nuclearity": f1_nuclearity,
}
logger.info(f"Test metrics: {epoch_metrics}")
results_writer.writerow(epoch_metrics)
# Saving model
if best_epoch == current_epoch:
model.save_pretrained(save_dir)
# Convert back to training
model.train()
logger.info("--------------------------------------------------------------------")
logger.info("Model training completed!")
logger.info("--------------------------------------------------------------------")
logger.info(f"The best F1 points for Relation is: {best_f1_relation:.3f}.")
logger.info(f"The best F1 points for Nuclearity is: {best_f1_nuclearity:.3f}")
logger.info(f"The best F1 points for Span is: {best_f1_span:.3f}")
best_results_writer.writerow(
{
"best_epoch": best_epoch,
"f1_relation": best_f1_relation,
"precision_relation": best_precision_relation,
"recall_relation": best_recall_relation,
"f1_span": best_f1_span,
"precision_span": best_precision_span,
"recall_span": best_recall_span,
"f1_nuclearity": best_f1_nuclearity,
"precision_nuclearity": best_precision_nuclearity,
"recall_nuclearity": best_recall_nuclearity,
}
)
# Segmenter training code
def sample_a_sorted_batch_from_numpy(input_x, output_y, batch_size):
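    """Sample a random segmenter batch, sort it by sentence length
    (longest first), and build decoder input indices by shifting the
    boundary targets right by one position and offsetting them by one.

    For example, targets [2, 5, 8] become decoder inputs [0, 3, 6].
    """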
input_x = np.array(input_x, dtype="object")
output_y = np.array(output_y, dtype="object")
if batch_size is not None:
select_index = random.sample(range(len(output_y)), batch_size)
else:
select_index = np.array(range(len(output_y)))
batch_x = copy.deepcopy(input_x[select_index])
batch_y = copy.deepcopy(output_y[select_index])
all_lens = np.array([len(x) for x in batch_x])
idx = np.argsort(all_lens)
idx = idx[::-1] # decreasing
batch_x = batch_x[idx]
batch_y = batch_y[idx]
    # Build decoder input indices: shift targets right by one, offset by one
batch_x_index = []
for i in range(len(batch_y)):
cur_y = batch_y[i]
temp = [x + 1 for x in cur_y]
temp.insert(0, 0)
temp.pop()
batch_x_index.append(temp)
all_lens = all_lens[idx]
return batch_x, batch_x_index, batch_y, all_lens
def get_batch_test(x, y, batch_size):
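    """Return a batch of (sentences, targets, lengths). If batch_size is
    None, the whole input is returned in its original order; otherwise a
    random sample of that size is drawn."""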
x = np.array(x, dtype="object")
y = np.array(y, dtype="object")
if batch_size is not None:
select_index = random.sample(range(len(y)), batch_size)
else:
select_index = np.array(range(len(y)))
batch_x = copy.deepcopy(x[select_index])
batch_y = copy.deepcopy(y[select_index])
all_lens = np.array([len(x) for x in batch_x])
return batch_x, batch_y, all_lens
def sample_batch(x, y, sample_size):
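    """Draw a random sample of ``sample_size`` (x, y) pairs; used to build
    the fixed train-sample evaluation set."""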
select_index = random.sample(range(len(y)), sample_size)
x = np.array(x, dtype="object")
y = np.array(y, dtype="object")
return x[select_index], y[select_index]
def get_batch_micro_metric(pre_b, ground_b):
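    """Count per-sentence correct/predicted/golden boundaries for the
    micro-averaged segmenter metrics. The final boundary (the end of the
    sentence, which is trivially a break) is excluded from both sides
    before counting."""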
all_c = []
all_r = []
all_g = []
for i in range(len(ground_b)):
index_of_1 = np.array(ground_b[i])
index_pre = pre_b[i]
index_pre = np.array(index_pre)
end_b = index_of_1[-1]
index_pre = index_pre[index_pre != end_b]
index_of_1 = index_of_1[index_of_1 != end_b]
no_correct = len(np.intersect1d(list(index_of_1), list(index_pre)))
all_c.append(no_correct)
all_r.append(len(index_pre))
all_g.append(len(index_of_1))
return all_c, all_r, all_g
# Unused
def get_batch_metric(pre_b, ground_b):
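    """Per-sentence (macro) precision, recall and F1 for predicted
    boundaries. Kept for reference; not called anywhere in this module."""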
b_pr = []
b_re = []
b_f1 = []
for i, cur_seq_y in enumerate(ground_b):
index_of_1 = np.where(cur_seq_y == 1)[0]
index_pre = pre_b[i]
no_correct = len(np.intersect1d(index_of_1, index_pre))
cur_pre = no_correct / len(index_pre)
cur_rec = no_correct / len(index_of_1)
cur_f1 = 2 * cur_pre * cur_rec / (cur_pre + cur_rec)
b_pr.append(cur_pre)
b_re.append(cur_rec)
b_f1.append(cur_f1)
return b_pr, b_re, b_f1
def check_accuracy(model, preprocessor, x, y, batch_size):
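    """Evaluate the segmenter over ``x``/``y`` in mini-batches and return
    the mean loss, micro-averaged precision/recall/F1, and the collected
    predictions. Note that all_x_save and all_index_decoder_y are returned
    empty; only the boundary lists are populated."""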
num_loops = int(np.ceil(len(y) / batch_size))
all_ave_loss = []
all_start_boundaries = []
all_end_boundaries = []
all_index_decoder_y = []
all_x_save = []
all_c = []
all_r = []
all_g = []
for i in range(num_loops):
start_idx = i * batch_size
end_idx = (i + 1) * batch_size
if end_idx > len(y):
end_idx = len(y)
batch_x, batch_y, all_lens = get_batch_test(
x[start_idx:end_idx], y[start_idx:end_idx], None
)
input_sentences_ids_batch, _ = preprocessor.get_elmo_char_ids(batch_x)
input_sentences_ids_batch = input_sentences_ids_batch.to(device=model.device)
output = model(input_sentences_ids_batch, all_lens, batch_y)
batch_ave_loss = output.loss
batch_start_boundaries = output.start_boundaries
batch_end_boundaries = output.end_boundaries
        all_ave_loss.append(batch_ave_loss.cpu().data.numpy())
all_start_boundaries.extend(batch_start_boundaries)
all_end_boundaries.extend(batch_end_boundaries)
ba_c, ba_r, ba_g = get_batch_micro_metric(batch_end_boundaries, batch_y)
all_c.extend(ba_c)
all_r.extend(ba_r)
all_g.extend(ba_g)
ba_pre = np.sum(all_c) / np.sum(all_r)
ba_rec = np.sum(all_c) / np.sum(all_g)
ba_f1 = 2 * ba_pre * ba_rec / (ba_pre + ba_rec)
return (
np.mean(all_ave_loss),
ba_pre,
ba_rec,
ba_f1,
(all_x_save, all_index_decoder_y, all_start_boundaries, all_end_boundaries),
)
def train_segmenter(cfg: RstPointerSegmenterTrainArgs) -> None:
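    """Train the RST pointer segmenter: load the preprocessed pickles,
    train with Adam, evaluate each epoch on a fixed train sample and the
    test set, and keep the checkpoint with the best test F1.

    Illustrative usage (assuming a config file matching
    RstPointerSegmenterTrainArgs, as in the ``__main__`` block below)::

        cfg = parse_args_and_load_config()
        train_segmenter(cfg)
    """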
logger.info(f"===== Training RST Pointer Segmenter =====")
setup(seed=cfg.seed)
train_data_dir = cfg.train_data_dir
test_data_dir = cfg.test_data_dir
save_dir = cfg.save_dir
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Device: {device}")
hidden_dim = cfg.hidden_dim
rnn_type = cfg.rnn
num_rnn_layers = cfg.num_rnn_layers
lr = cfg.lr
dropout = cfg.dropout
wd = cfg.weight_decay
batch_size = cfg.batch_size
lr_decay_epoch = cfg.lr_decay_epoch
elmo_size = cfg.elmo_size
epochs = cfg.epochs
use_bilstm = cfg.use_bilstm
is_batch_norm = cfg.use_batch_norm
    with open(os.path.join(train_data_dir, "tokenized_sentences.pickle"), "rb") as f:
        tr_x = pickle.load(f)
    with open(os.path.join(train_data_dir, "edu_breaks.pickle"), "rb") as f:
        tr_y = pickle.load(f)
    with open(os.path.join(test_data_dir, "tokenized_sentences.pickle"), "rb") as f:
        dev_x = pickle.load(f)
    with open(os.path.join(test_data_dir, "edu_breaks.pickle"), "rb") as f:
        dev_y = pickle.load(f)
model_config = RstPointerSegmenterConfig(
hidden_dim=hidden_dim,
dropout_prob=dropout,
use_bilstm=use_bilstm,
num_rnn_layers=num_rnn_layers,
rnn_type=rnn_type,
is_batch_norm=is_batch_norm,
elmo_size=elmo_size,
)
model = RstPointerSegmenterModel(model_config)
model.to(device=device)
preprocessor = RstPreprocessor()
# Arbitrary eval_size
eval_size = len(dev_x) * 2 // 3
test_train_x, test_train_y = sample_batch(tr_x, tr_y, eval_size)
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd
)
num_iterations = int(np.round(len(tr_y) / batch_size))
os.makedirs(save_dir, exist_ok=True)
best_results_writer = CsvWriter(
file_path=os.path.join(save_dir, "best_results.csv"),
fieldnames=["best_epoch", "precision", "recall", "f1"],
)
results_writer = CsvWriter(
file_path=os.path.join(save_dir, "results.csv"),
fieldnames=[
"current_epoch",
"train_loss",
"train_precision",
"train_recall",
"train_f1",
"dev_loss",
"dev_precision",
"dev_recall",
"dev_f1",
],
)
    # Initialise best-metric trackers so they are defined even if no epoch
    # improves on the starting F1
    best_epoch = 0
    best_f1 = 0
    best_pre = 0
    best_rec = 0
for current_epoch in range(epochs):
adjust_learning_rate(optimizer, current_epoch, 0.8, lr_decay_epoch)
track_epoch_loss = []
for current_iter in range(num_iterations):
(
batch_x,
batch_x_index,
batch_y,
all_lens,
) = sample_a_sorted_batch_from_numpy(tr_x, tr_y, batch_size)
model.zero_grad()
input_sentences_ids_batch, _ = preprocessor.get_elmo_char_ids(batch_x)
input_sentences_ids_batch = input_sentences_ids_batch.to(
device=model.device
)
output = model(input_sentences_ids_batch, all_lens, batch_y)
loss = output.loss
loss_value = float(loss.data)
track_epoch_loss.append(loss_value)
logger.info(
f"Epoch: {current_epoch + 1}/{epochs}, "
f"iteration: {current_iter + 1}/{num_iterations}, "
f"loss: {loss_value:.3f}"
)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()
model.eval()
logger.info(
"Running end of epoch evaluations on sample train data and test data..."
)
tr_batch_ave_loss, tr_pre, tr_rec, tr_f1, tr_visdata = check_accuracy(
model, preprocessor, test_train_x, test_train_y, batch_size
)
dev_batch_ave_loss, dev_pre, dev_rec, dev_f1, dev_visdata = check_accuracy(
model, preprocessor, dev_x, dev_y, batch_size
)
_, _, _, all_end_boundaries = dev_visdata
logger.info(
f"train sample -- loss: {tr_batch_ave_loss:.3f}, "
f"precision: {tr_pre:.3f}, recall: {tr_rec:.3f}, f1: {tr_f1:.3f}"
)
logger.info(
f"test sample -- loss: {dev_batch_ave_loss:.3f}, "
f"precision: {dev_pre:.3f}, recall: {dev_rec:.3f}, f1: {dev_f1:.3f}"
)
if best_f1 < dev_f1:
best_f1 = dev_f1
best_rec = dev_rec
best_pre = dev_pre
best_epoch = current_epoch
results_writer.writerow(
{
"current_epoch": current_epoch,
"train_loss": tr_batch_ave_loss,
"train_precision": tr_pre,
"train_recall": tr_rec,
"train_f1": tr_f1,
"dev_loss": dev_batch_ave_loss,
"dev_precision": dev_pre,
"dev_recall": dev_rec,
"dev_f1": dev_f1,
}
)
if current_epoch == best_epoch:
logger.info("Saving best model...")
model.save_pretrained(save_dir)
with open(os.path.join(save_dir, "best_segmentation.pickle"), "wb") as f:
pickle.dump(all_end_boundaries, f)
model.train()
best_results_writer.writerow(
{
"best_epoch": best_epoch,
"precision": best_pre,
"recall": best_rec,
"f1": best_f1,
}
)
if __name__ == "__main__":
    cfg = parse_args_and_load_config()
    print(cfg)
    if isinstance(cfg, RstPointerSegmenterTrainArgs):
        train_segmenter(cfg)
    elif isinstance(cfg, RstPointerParserTrainArgs):
        train_parser(cfg)