from math import inf
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .config import CsgConfig
from .modules.conv_decoder import ConvDecoder
from .modules.conv_encoder import ConvEncoder
from .utils import Buffer, Beam
class CsgPreTrainedModel(PreTrainedModel):
config_class = CsgConfig
base_model_prefix = "csg"
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class CsgModel(CsgPreTrainedModel):
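    """Convolutional context-aware sequence-to-sequence model.

    Wires together three convolutional modules: ``encoder`` over the
    source sentence, ``auxencoder`` over an auxiliary context sentence,
    and ``decoder``, which conditions on both encoder outputs to generate
    the target sequence.
    """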
def __init__(self, config: CsgConfig):
super().__init__(config)
self.config = config
self.encoder = ConvEncoder(
num_embeddings=config.source_vocab_size,
embedding_dim=config.embedding_dim,
max_seq_len=config.src_max_seq_len,
padding_idx=config.padding_idx,
token_dropout=config.dropout,
hidden_dim=config.hidden_dim,
kernel_size=config.kernel_size,
dropout=config.dropout,
num_conv_layers=config.num_encoders,
)
self.auxencoder = ConvEncoder(
num_embeddings=config.source_vocab_size,
embedding_dim=config.embedding_dim,
max_seq_len=config.ctx_max_seq_len,
padding_idx=config.padding_idx,
token_dropout=config.dropout,
hidden_dim=config.hidden_dim,
kernel_size=config.kernel_size,
dropout=config.dropout,
num_conv_layers=config.num_aux_encoders,
)
self.decoder = ConvDecoder(
num_embeddings=config.target_vocab_size,
embedding_dim=config.embedding_dim,
max_seq_len=config.trg_max_seq_len,
padding_idx=config.padding_idx,
token_dropout=config.dropout,
hidden_dim=config.hidden_dim,
kernel_size=config.kernel_size,
dropout=config.dropout,
num_conv_layers=config.num_decoders,
)
    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Please use the decode method to get model predictions."
        )
def decode(self, batch_source_ids, batch_context_ids):
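        """Decode a batch of texts sentence by sentence via beam search.

        ``batch_source_ids`` and ``batch_context_ids`` are nested alike:
        one entry per text, each a sequence of per-sentence 1-D tensors of
        token ids. Every sentence is decoded with its own context sentence,
        and the start marker and trailing EOS are stripped from the result.

        Returns a list (texts) of lists (sentences) of decoded token ids.

        Illustrative usage (tensor names are placeholders)::

            # one text with two sentences, each a 1-D LongTensor of ids
            out = model.decode(
                [[src_sent_1, src_sent_2]], [[ctx_sent_1, ctx_sent_2]]
            )
            # out[0][0] holds the token ids decoded for the first sentence
        """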
batch_output = []
for text_source_ids, text_context_ids in zip(
batch_source_ids, batch_context_ids
):
text_output = []
for sentence_source_ids, sentence_context_ids in zip(
text_source_ids, text_context_ids
):
encoder_out_dict = self.encoder(sentence_source_ids.reshape(1, -1))
auxencoder_out_dict = self.auxencoder(
sentence_context_ids.reshape(1, -1)
)
best_sentence_indices = self._decode_one(
encoder_out_dict=encoder_out_dict,
auxencoder_out_dict=auxencoder_out_dict,
beam_size=self.config.beam_size,
                    max_len=len(sentence_source_ids)
                    + 5,  # the + 5 slack is arbitrary; strictly only + 1 is needed for EOS
)
                # Drop the leading start marker and the trailing EOS token
                text_output.append(best_sentence_indices[1:-1])
batch_output.append(text_output)
return batch_output
def _decode_one(
self,
encoder_out_dict,
auxencoder_out_dict,
beam_size,
max_len=None,
):
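        """Run beam search for a single (source, context) sentence pair.

        Keeps ``beam_size`` live hypotheses and, at every step, picks the
        ``beam_size * 2`` best expansions across all of them. Hypotheses
        that emit EOS are finalised with a length-normalised score; the
        rest continue until ``max_len`` steps have passed or no live
        hypothesis can still beat the worst finalised candidate.
        """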
max_len = max_len if max_len is not None else self.config.trg_max_seq_len
incremental_state_buffer = Buffer(max_len=self.config.num_decoders)
finalised_cands_beam = Beam(beam_size=beam_size)
num_cands = beam_size * 2
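        # Every hypothesis starts from the EOS token, which doubles as the
        # start-of-sequence marker here.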
prev_output_tokens = torch.tile(
torch.LongTensor([self.config.eos_idx]), (beam_size, 1)
)
for step in range(max_len):
decoder_output = self.decoder(
prev_output_tokens=prev_output_tokens,
encoder_out_dict=encoder_out_dict,
auxencoder_out_dict=auxencoder_out_dict,
incremental_state=incremental_state_buffer,
)
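            # The first step expands a single start hypothesis over the
            # vocabulary; later steps jointly expand all beam_size hypotheses.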
if step == 0:
seq_values, topk_cand = torch.topk(
torch.log_softmax(decoder_output[:, -1:, :], dim=2)[0, :, :].data,
num_cands,
)
# Check which candidates have been finalised by checking whether
# the EOS token has been generated
unfinalised_indices = topk_cand != self.config.eos_idx
finalised_indices = topk_cand == self.config.eos_idx
# If there are any finalised candidates, add them to the beam
if torch.any(finalised_indices):
                    # keep a column shape so the cat over dim=1 below works
                    finalised_cand_lst = topk_cand[finalised_indices].reshape(-1, 1)
finalised_cand_lst = torch.cat(
(
torch.tile(
torch.LongTensor([self.config.eos_idx]),
(finalised_cand_lst.shape[0], 1),
),
finalised_cand_lst,
),
dim=1,
).tolist()
                    # Normalise scores by the number of tokens generated
                    finalised_score_lst = (
                        (seq_values[finalised_indices] / (step + 1))
                        .reshape(-1)
                        .tolist()
                    )
# Add the finalised candidates
finalised_cands_beam.add_elements(
finalised_score_lst, finalised_cand_lst
)
                # Continue the search with the top beam_size unfinalised
                # candidates
unfinalised_cand_lst = topk_cand[unfinalised_indices][:beam_size]
seq_values = seq_values[unfinalised_indices][:beam_size]
prev_output_tokens = torch.cat(
(
prev_output_tokens,
unfinalised_cand_lst.reshape(-1, 1),
),
dim=1,
)
else:
log_proba = torch.log_softmax(decoder_output, dim=2)
log_proba[:, :, self.config.padding_idx] = -inf
seq_values = (
log_proba.reshape(beam_size, -1) + seq_values.reshape(beam_size, -1)
).reshape(-1)
seq_values, topk_cand = torch.topk(seq_values.data, num_cands)
                # topk ran over the flattened (beam, vocab) scores, so each
                # candidate index encodes both the hypothesis it extends and
                # the token it appends
                new_token_vocab_indices = topk_cand % self.config.target_vocab_size
                prev_sequence_indices = torch.div(
                    topk_cand,
                    self.config.target_vocab_size,
                    rounding_mode="floor",
                )
# Check which candidates have been finalised by checking whether
# the EOS token has been generated
unfinalised_indices = new_token_vocab_indices != self.config.eos_idx
finalised_indices = new_token_vocab_indices == self.config.eos_idx
# If there are any finalised candidates, add them to the beam
if torch.any(finalised_indices):
finalised_new_token_vocab_indices = new_token_vocab_indices[
finalised_indices
].reshape(-1, 1)
finalised_prev_sequence_indices = prev_sequence_indices[
finalised_indices
]
finalised_cand_lst = torch.cat(
(
prev_output_tokens[finalised_prev_sequence_indices],
finalised_new_token_vocab_indices,
),
dim=1,
).tolist()
                    # Normalise scores by the number of tokens generated
                    finalised_score_lst = (
                        (seq_values[finalised_indices] / (step + 1))
                        .reshape(-1)
                        .tolist()
                    )
# Add the elements to the beam
finalised_cands_beam.add_elements(
finalised_score_lst, finalised_cand_lst
)
# Retain only the top beam_size worth of unfinalised hypotheses
unfinalised_new_token_vocab_indices = new_token_vocab_indices[
unfinalised_indices
][:beam_size].reshape(beam_size, -1)
unfinalised_prev_sequence_indices = prev_sequence_indices[
unfinalised_indices
][:beam_size]
prev_output_tokens = torch.cat(
(
prev_output_tokens[unfinalised_prev_sequence_indices],
unfinalised_new_token_vocab_indices,
),
dim=1,
)
seq_values = seq_values[unfinalised_indices][:beam_size]
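            # Reorder each decoder layer's cached incremental state so it
            # follows the surviving hypotheses. Cycling the buffer once per
            # layer (take the oldest state, reindex it, push it back) visits
            # every layer exactly once, assuming Buffer behaves as a FIFO.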
            for _ in range(self.config.num_decoders):
                reordered_state = incremental_state_buffer.get_first_element()[
                    unfinalised_prev_sequence_indices, :, :
                ]
                incremental_state_buffer.add_element(reordered_state)
            # seq_values is sorted, so seq_values[0] is the best unfinalised
            # hypothesis; dividing by max_len gives the best length-normalised
            # score it could still reach. If even that is worse than the worst
            # finalised candidate, generation cannot improve, so we stop.
            if (
                len(finalised_cands_beam.get_elements()) != 0
                and (seq_values / max_len)[0] < finalised_cands_beam.get_lowest_score()
            ):
                break
return finalised_cands_beam.get_best_element()["indices"]