RstPointerSegmenterModel

class RstPointerSegmenterModel(config: sgnlp.models.rst_pointer.config.RstPointerSegmenterConfig)[source]

This model performs discourse segmentation.

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

config (RstPointerSegmenterConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Use the from_pretrained method to load the model weights.

Example:

from sgnlp.models.rst_pointer import RstPointerSegmenterConfig, RstPointerSegmenterModel

# Method 1: Loading a default model
segmenter_config = RstPointerSegmenterConfig()
segmenter = RstPointerSegmenterModel(segmenter_config)

# Method 2: Loading from pretrained
segmenter_config = RstPointerSegmenterConfig.from_pretrained(
    'https://storage.googleapis.com/sgnlp/models/rst_pointer/segmenter/config.json')
segmenter = RstPointerSegmenterModel.from_pretrained(
    'https://storage.googleapis.com/sgnlp/models/rst_pointer/segmenter/pytorch_model.bin',
    config=segmenter_config)

decoder(h_n, h_end, batch_x_lens, batch_y=None)[source]
Parameters
  • h_n – all hidden states

  • h_end – final hidden state

  • batch_x_lens – lengths of x (i.e. number of tokens)

  • batch_y – Optional. Provide to have the loss computed.

Returns

A tuple containing the following values:

  • batch_start_boundaries – array of start tokens for each predicted EDU (elementary discourse unit)

  • batch_end_boundaries – array of end tokens for each predicted EDU

  • batch_align_matrix

  • batch_loss – optional loss metric; calculated only if batch_y is provided

Return type

tuple
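As a rough illustration of how paired start and end boundary arrays could be turned into EDU token spans, here is a minimal sketch. It assumes the boundaries are inclusive token indices into a single sentence; the helper name split_into_edus and the example tokens are hypothetical and are not part of sgnlp.

```python
from typing import List, Sequence


def split_into_edus(tokens: Sequence[str],
                    start_boundaries: Sequence[int],
                    end_boundaries: Sequence[int]) -> List[List[str]]:
    """Group tokens into EDUs from paired boundary indices.

    Assumption (not confirmed by the sgnlp docs): each start/end pair
    holds inclusive token indices for one predicted EDU.
    """
    if len(start_boundaries) != len(end_boundaries):
        raise ValueError("start and end boundary arrays must have the same length")

    edus = []
    for start, end in zip(start_boundaries, end_boundaries):
        # Reject pairs that fall outside the sentence or are reversed.
        if not 0 <= start <= end < len(tokens):
            raise ValueError(f"invalid boundary pair ({start}, {end})")
        edus.append(list(tokens[start:end + 1]))
    return edus


# Hypothetical sentence with two EDUs.
tokens = ["He", "left", "because", "it", "rained", "."]
edus = split_into_edus(tokens, [0, 2], [1, 5])
```

With these example inputs, the first EDU covers tokens 0 to 1 and the second covers tokens 2 to 5. If the model's boundaries turn out to be exclusive or offset, the slice in the loop would need adjusting.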

forward(tokenized_sentence_ids, sentence_lens, labels=None)[source]
Parameters
  • tokenized_sentence_ids – Token IDs.

  • sentence_lens – Sentence lengths.

  • labels – Optional. Provide if loss is needed.

Returns

output (RstPointerSegmenterModelOutput)