RstPointerSegmenterModel
class RstPointerSegmenterModel(config: sgnlp.models.rst_pointer.config.RstPointerSegmenterConfig)

This model performs discourse segmentation, i.e. it splits text into elementary discourse units (EDUs).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
Parameters:
- config (RstPointerSegmenterConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Use the from_pretrained method to load the model weights.
Example:
```python
from sgnlp.models.rst_pointer import RstPointerSegmenterConfig, RstPointerSegmenterModel

# Method 1: Loading a default model (configuration only; weights are not loaded)
segmenter_config = RstPointerSegmenterConfig()
segmenter = RstPointerSegmenterModel(segmenter_config)

# Method 2: Loading from pretrained (configuration and weights)
segmenter_config = RstPointerSegmenterConfig.from_pretrained(
    'https://storage.googleapis.com/sgnlp/models/rst_pointer/segmenter/config.json')
segmenter = RstPointerSegmenterModel.from_pretrained(
    'https://storage.googleapis.com/sgnlp/models/rst_pointer/segmenter/pytorch_model.bin',
    config=segmenter_config)
```
decoder(h_n, h_end, batch_x_lens, batch_y=None)

Parameters:
- h_n – All hidden states.
- h_end – Final hidden state.
- batch_x_lens – Lengths of x (i.e. the number of tokens in each input).
- batch_y – Optional. Provide to compute the loss.
Returns:
A tuple containing the following values:
- batch_start_boundaries: array of start tokens for each predicted EDU.
- batch_end_boundaries: array of end tokens for each predicted EDU.
- batch_align_matrix: (undocumented).
- batch_loss: optional loss, calculated only if batch_y is provided.

Return type:
tuple
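The decoder reports each predicted EDU as a pair of start and end boundaries. As a rough illustration of how such pairs map back to text, the sketch below slices a tokenized sentence into EDUs. It assumes (this page does not confirm it) that each boundary is a token index into the sentence and that end indices are inclusive; `boundaries_to_edus` is a hypothetical helper, not part of sgnlp.

```python
from typing import List, Sequence


def boundaries_to_edus(
    tokens: Sequence[str],
    start_boundaries: Sequence[int],
    end_boundaries: Sequence[int],
) -> List[List[str]]:
    """Slice one sentence's tokens into EDUs (hypothetical helper).

    Assumption: each (start, end) pair holds token indices, with the end
    index inclusive. Adjust the slice if the real outputs differ.
    """
    if len(start_boundaries) != len(end_boundaries):
        raise ValueError("start and end boundary arrays must have the same length")
    edus = []
    for start, end in zip(start_boundaries, end_boundaries):
        # Reject pairs that fall outside the sentence or are reversed.
        if not 0 <= start <= end < len(tokens):
            raise ValueError(f"invalid boundary pair ({start}, {end})")
        edus.append(list(tokens[start : end + 1]))  # +1 because end is inclusive
    return edus


tokens = ["Although", "it", "rained", ",", "we", "went", "out", "."]
print(boundaries_to_edus(tokens, [0, 4], [3, 7]))
# [['Although', 'it', 'rained', ','], ['we', 'went', 'out', '.']]
```

For a batch, the same function would be applied per sentence, pairing each sentence's tokens with its own entries from batch_start_boundaries and batch_end_boundaries.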
forward(tokenized_sentence_ids, sentence_lens, labels=None)

Parameters:
- tokenized_sentence_ids – Token IDs.
- sentence_lens – Sentence lengths.
- labels – Optional. Provide if the loss is needed.

Returns:
output (RstPointerSegmenterModelOutput)
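forward takes token IDs together with per-sentence lengths, which suggests a padded batch in which the lengths mark where the real tokens end. The minimal pure-Python sketch below builds inputs of that shape. It assumes word-level integer IDs, right padding, and a padding ID of 0; none of these are stated in this section, and `pad_batch` is a hypothetical helper, not an sgnlp function.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    batch_token_ids: Sequence[Sequence[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad token-ID sequences to a common length (hypothetical helper).

    Returns the padded batch and the true length of each sequence.
    Assumption: pad_id=0; use the padding index of your actual vocabulary.
    """
    sentence_lens = [len(ids) for ids in batch_token_ids]
    max_len = max(sentence_lens, default=0)  # 0 for an empty batch
    padded = [list(ids) + [pad_id] * (max_len - len(ids)) for ids in batch_token_ids]
    return padded, sentence_lens


padded, lens = pad_batch([[5, 12, 9], [5, 7]])
print(padded)  # [[5, 12, 9], [5, 7, 0]]
print(lens)    # [3, 2]
```

Whether forward expects tensors rather than lists, and whether the batch must be sorted by length, is not documented here; check the library's own preprocessing code before relying on this layout.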