RstPointerParserModel
class RstPointerParserModel(config: sgnlp.models.rst_pointer.config.RstPointerParserConfig)

This model performs Rhetorical Structure Theory (RST) discourse parsing.
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
Parameters:
config (RstPointerParserConfig) – Model configuration class with all the parameters of the model. Initializing with a config does not load the weights associated with the model, only the configuration. Use the from_pretrained method to load the model weights.
Example:

    from sgnlp.models.rst_pointer import RstPointerParserConfig, RstPointerParserModel

    # Method 1: Initialize from a default config (pretrained weights are not loaded)
    parser_config = RstPointerParserConfig()
    parser = RstPointerParserModel(parser_config)

    # Method 2: Load the config and weights from pretrained
    parser_config = RstPointerParserConfig.from_pretrained(
        'https://storage.googleapis.com/sgnlp/models/rst_pointer/parser/config.json')
    parser = RstPointerParserModel.from_pretrained(
        'https://storage.googleapis.com/sgnlp/models/rst_pointer/parser/pytorch_model.bin',
        config=parser_config)
forward(input_sentence_ids, edu_breaks, sentence_lengths, label_index=None, parsing_index=None, generate_splits=True)

Parameters:
input_sentence_ids – Input sentence IDs.
edu_breaks – Token positions of the elementary discourse unit (EDU) breaks.
sentence_lengths – Lengths of the input sentences.
label_index – Label IDs. Required only when the loss is computed.
parsing_index – Parsing IDs. Required only when the loss is computed.
generate_splits – Whether to return splits.
Returns:
output (RstPointerParserModelOutput)
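The first three forward parameters describe the same batch from different angles: the token IDs of each sentence, where each sentence splits into EDUs, and how long each sentence is before padding. The sketch below shows one way these could be assembled in plain Python, assuming each sentence already arrives segmented into EDUs (lists of tokens), that an EDU break is the index of the last token of each EDU, and that sentences are right-padded with a hypothetical PAD_ID. The names build_parser_inputs, PAD_ID, and the toy vocabulary are illustrative only and are not part of sgnlp; the library's own preprocessing (tokenization, ID scheme, conversion to tensors) may differ.

```python
from typing import Dict, List, Tuple

# Hypothetical padding ID; the real value depends on the model's input scheme.
PAD_ID = 0


def build_parser_inputs(
    batch_edus: List[List[List[str]]],
    vocab: Dict[str, int],
    unk_id: int = 1,
) -> Tuple[List[List[int]], List[List[int]], List[int]]:
    """Flatten pre-segmented sentences into padded IDs, EDU breaks and lengths.

    batch_edus[i] is sentence i, given as a list of EDUs, each a list of tokens.
    Assumption: an EDU break is the index of the last token of that EDU.
    """
    sentence_ids: List[List[int]] = []
    edu_breaks: List[List[int]] = []
    sentence_lengths: List[int] = []
    for edus in batch_edus:
        ids: List[int] = []
        breaks: List[int] = []
        for edu in edus:
            # Map tokens to IDs, falling back to unk_id for unknown tokens.
            ids.extend(vocab.get(tok, unk_id) for tok in edu)
            breaks.append(len(ids) - 1)  # position of this EDU's final token
        sentence_ids.append(ids)
        edu_breaks.append(breaks)
        sentence_lengths.append(len(ids))  # unpadded length
    # Right-pad every sentence to the longest one in the batch.
    max_len = max(sentence_lengths, default=0)
    padded = [ids + [PAD_ID] * (max_len - len(ids)) for ids in sentence_ids]
    return padded, edu_breaks, sentence_lengths


if __name__ == "__main__":
    toy_vocab = {"the": 2, "cat": 3, "sat": 4, "because": 5,
                 "it": 6, "was": 7, "tired": 8, ".": 9}
    batch = [
        [["the", "cat", "sat"], ["because", "it", "was", "tired", "."]],
        [["it", "sat", "."]],
    ]
    ids, breaks, lengths = build_parser_inputs(batch, toy_vocab)
    print(ids)      # [[2, 3, 4, 5, 6, 7, 8, 9], [6, 4, 9, 0, 0, 0, 0, 0]]
    print(breaks)   # [[2, 7], [2]]
    print(lengths)  # [8, 3]
```

Keeping sentence_lengths separate from the padded IDs lets a model ignore padding positions, which is why a length input typically accompanies a padded batch; whether this parser uses the lengths exactly that way is not stated here.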