RstPointerParserModel

class RstPointerParserModel(config: sgnlp.models.rst_pointer.config.RstPointerParserConfig)

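This model performs discourse parsing: given the token IDs of each sentence and the positions of its elementary discourse unit (EDU) breaks, it predicts the discourse tree over those EDUs.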

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
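As a rough illustration of what "discourse parsing" means here: in the RST Pointer approach (Lin et al., 2019), the tree over a sentence's EDUs is built top-down and depth-first, with a pointer network choosing at each step where to split the current span, so a sentence with n EDUs needs exactly n - 1 splits. The sketch below shows only that decoding loop. `top_down_splits` and `choose_split` are hypothetical names: `choose_split` stands in for the pointer network's scoring, and none of this is part of the sgnlp API.

```python
from typing import Callable, List, Tuple

Span = Tuple[int, int]  # inclusive (start, end) EDU indices


def top_down_splits(
    num_edus: int, choose_split: Callable[[int, int], int]
) -> List[Tuple[Span, Span]]:
    """Build a binary tree over EDUs top-down, depth-first, left to right.

    choose_split(start, end) is a hypothetical stand-in for the pointer
    network: it must return k with start <= k < end, splitting [start, end]
    into [start, k] and [k + 1, end]. Returns the splits in decoding order.
    """
    if num_edus < 1:
        raise ValueError("need at least one EDU")
    splits: List[Tuple[Span, Span]] = []
    stack: List[Span] = [(0, num_edus - 1)]
    while stack:
        start, end = stack.pop()
        if start == end:  # a single EDU is a leaf; nothing to split
            continue
        k = choose_split(start, end)
        if not start <= k < end:
            raise ValueError(f"invalid split {k} for span ({start}, {end})")
        left, right = (start, k), (k + 1, end)
        splits.append((left, right))
        # Push the right child first so the left child is expanded next.
        stack.append(right)
        stack.append(left)
    return splits


# Hypothetical chooser: always split after the first EDU (right-branching tree).
print(top_down_splits(4, lambda start, end: start))
# → [((0, 0), (1, 3)), ((1, 1), (2, 3)), ((2, 2), (3, 3))]
```

Swapping the stand-in chooser changes only the tree shape: splitting at the midpoint, for example, yields a balanced tree, but the number of splits stays n - 1.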

Parameters

config (RstPointerParserConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Use the from_pretrained method to load the model weights.

Example:

from sgnlp.models.rst_pointer import RstPointerParserConfig, RstPointerParserModel

# Method 1: Default configuration (weights are randomly initialized, not pretrained)
parser_config = RstPointerParserConfig()
parser = RstPointerParserModel(parser_config)

# Method 2: Load the pretrained configuration and weights
parser_config = RstPointerParserConfig.from_pretrained(
    'https://storage.googleapis.com/sgnlp/models/rst_pointer/parser/config.json')
parser = RstPointerParserModel.from_pretrained(
    'https://storage.googleapis.com/sgnlp/models/rst_pointer/parser/pytorch_model.bin',
    config=parser_config)

forward(input_sentence_ids, edu_breaks, sentence_lengths, label_index=None, parsing_index=None, generate_splits=True)
Parameters
  • input_sentence_ids – Token IDs of the input sentences.

  • edu_breaks – Token positions of the elementary discourse unit (EDU) breaks in each sentence.

  • sentence_lengths – Length of each input sentence.

  • label_index – Gold label IDs. Required only when the loss is computed.

  • parsing_index – Gold parsing IDs. Required only when the loss is computed.

  • generate_splits – Whether to generate and return the predicted splits.

Returns

output (RstPointerParserModelOutput)
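The three required inputs of forward must describe the same tokens. The following sketch builds them from sentences that are already segmented into EDUs, under two assumptions not stated above: each EDU break is the 0-based index of the last token of an EDU within its sentence, and token IDs come from a toy vocabulary. `build_parser_inputs` is a hypothetical helper; it does not pad the batch or convert it to tensors, and sgnlp's own preprocessing may use a different ID scheme.

```python
from typing import Dict, List, Tuple


def build_parser_inputs(
    segmented_sentences: List[List[List[str]]],
    vocab: Dict[str, int],
    unk_id: int = 0,
) -> Tuple[List[List[int]], List[List[int]], List[int]]:
    """Hypothetical helper producing (input_sentence_ids, edu_breaks, sentence_lengths).

    segmented_sentences[i] is sentence i as a list of EDUs, each a list of tokens.
    ASSUMPTION: an EDU break is the 0-based index of the EDU's last token.
    """
    input_sentence_ids: List[List[int]] = []
    edu_breaks: List[List[int]] = []
    sentence_lengths: List[int] = []
    for edus in segmented_sentences:
        ids: List[int] = []
        breaks: List[int] = []
        for edu in edus:
            # Map tokens to IDs, falling back to unk_id for out-of-vocabulary tokens.
            ids.extend(vocab.get(token, unk_id) for token in edu)
            breaks.append(len(ids) - 1)  # index of this EDU's final token
        input_sentence_ids.append(ids)
        edu_breaks.append(breaks)
        sentence_lengths.append(len(ids))
    return input_sentence_ids, edu_breaks, sentence_lengths


# Toy vocabulary; "." is out of vocabulary and maps to unk_id (0).
vocab = {"the": 1, "cat": 2, "sat": 3, "because": 4, "it": 5, "was": 6, "tired": 7}
sentence = [["the", "cat", "sat"], ["because", "it", "was", "tired", "."]]
print(build_parser_inputs([sentence], vocab))
# → ([[1, 2, 3, 4, 5, 6, 7, 0]], [[2, 7]], [8])
```

Deriving all three values from the same token list in one pass is the point of this design: if the break positions were computed on a different tokenization than the IDs, they would point at the wrong tokens.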