PointerAtten

class PointerAtten(atten_model: str, hidden_size: int)

Pointer attention module used in the RST parser network.
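The atten_model string presumably selects how an encoder state is scored against the decoder state. The sketch below is a minimal, framework-free illustration of that idea. It assumes two hypothetical variants named "Dotproduct" and "Biaffine"; these names, the formulas, and the helper make_scorer are assumptions, not confirmed by this page. Plain Python lists stand in for torch.Tensor.

```python
from typing import Callable, List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def make_scorer(atten_model: str, hidden_size: int) -> Callable[[Vector, Vector], float]:
    """Return a hypothetical score(enc, dec) function for the named variant."""
    if atten_model == "Dotproduct":
        # Assumed variant: score = enc . dec
        return lambda enc, dec: dot(enc, dec)
    if atten_model == "Biaffine":
        # Assumed variant: score = enc^T W dec + u . enc.
        # W and u would be learned parameters; here W is the identity and
        # u is all ones, purely for illustration.
        W = [[1.0 if i == j else 0.0 for j in range(hidden_size)]
             for i in range(hidden_size)]
        u = [1.0] * hidden_size

        def biaffine(enc: Vector, dec: Vector) -> float:
            W_dec = [dot(row, dec) for row in W]  # W @ dec
            return dot(enc, W_dec) + dot(u, enc)

        return biaffine
    raise ValueError(f"unknown atten_model: {atten_model!r}")
```

For example, make_scorer("Dotproduct", 2)([1.0, 2.0], [3.0, 4.0]) gives 11.0, and the "Biaffine" variant with the identity W adds the bias term u . enc = 3.0, giving 14.0.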

forward(encoder_outputs: torch.Tensor, curr_decoder_outputs: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Forward pass of the pointer attention model.

Parameters
  • encoder_outputs (torch.Tensor) – output tensor of the encoder RNN.

  • curr_decoder_outputs (torch.Tensor) – output tensor of the decoder RNN at the current decoding step.

Returns

The softmax and log-softmax tensors of the pointer attention scores.

Return type

Tuple[torch.Tensor, torch.Tensor]
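The computation this method describes can be sketched as follows: score every encoder position against the current decoder state, then normalize those scores into a softmax and a log-softmax over positions. This is a minimal sketch, not the library's implementation. It assumes dot-product scoring, a single (unbatched) example, and plain Python lists in place of torch.Tensor; the function name pointer_attention is hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def pointer_attention(encoder_outputs: List[Vector],
                      curr_decoder_output: Vector) -> Tuple[Vector, Vector]:
    """Hypothetical sketch: return (softmax, log_softmax) over encoder positions."""
    # One scalar score per encoder position (assumed dot-product scoring).
    scores = [sum(e * d for e, d in zip(enc, curr_decoder_output))
              for enc in encoder_outputs]
    # Log-sum-exp with max subtraction keeps the normalizer numerically stable.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    log_probs = [s - log_z for s in scores]
    # Softmax recovered from the log-probabilities.
    probs = [math.exp(lp) for lp in log_probs]
    return probs, log_probs
```

With encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] and a decoder state [2.0, 1.0], the scores are [2.0, 1.0, 3.0], so the third position receives the highest probability. Returning both tensors is plausibly a convenience: the softmax can be used to pick the pointed-to position, while the log-softmax feeds a negative log-likelihood loss directly. In PyTorch the equivalent normalization would be torch.softmax and torch.log_softmax along the sequence dimension.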