ConvAttention

class ConvAttention(hidden_dim, token_embedding_dim, normalization_constant=0.5)

Attention module that compares the decoder states, combined with the target token embeddings, against the encoder outputs.

forward(Y, T, E, ES, encoder_padding_mask=None)
Y : torch.Tensor

    ConvGLU output for the tokens from [BOS] through the (n-1)th token. Shape: (batch size, sequence length, hidden dim).

T : torch.Tensor

    Target token embeddings for the tokens from [BOS] through the (n-1)th token. Shape: (batch size, sequence length, token embedding dim).

E : torch.Tensor

    Encoder outputs for all source/context tokens. Shape: (batch size, sequence length, token embedding dim).

ES : torch.Tensor

    Element-wise sum of the token embeddings and encoder outputs for all source/context tokens. Shape: (batch size, sequence length, token embedding dim).

encoder_padding_mask : torch.Tensor, optional

    Mask marking padded positions in the source sequence so that they receive no attention weight.
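The layer described above can be sketched as follows. This is a minimal, hypothetical implementation in the style of convolutional sequence-to-sequence attention; the projection layers, the sqrt(0.5) residual scaling, and the source-length rescaling are assumptions drawn from that family of models, not confirmed details of this class:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvAttention(nn.Module):
    """Sketch of a ConvS2S-style attention layer (details are assumptions)."""

    def __init__(self, hidden_dim, token_embedding_dim, normalization_constant=0.5):
        super().__init__()
        # Assumed: linear maps between the decoder hidden space and the
        # embedding space, so decoder states can be scored against E.
        self.in_projection = nn.Linear(hidden_dim, token_embedding_dim)
        self.out_projection = nn.Linear(token_embedding_dim, hidden_dim)
        self.normalization_constant = normalization_constant

    def forward(self, Y, T, E, ES, encoder_padding_mask=None):
        residual = Y
        # Combine the decoder state with the target embedding; the sqrt
        # scaling keeps the variance of the sum roughly constant.
        x = (self.in_projection(Y) + T) * math.sqrt(self.normalization_constant)
        # Attention scores: (batch, target length, source length).
        scores = torch.bmm(x, E.transpose(1, 2))
        if encoder_padding_mask is not None:
            # Padded source positions get zero attention mass.
            scores = scores.masked_fill(
                encoder_padding_mask.unsqueeze(1), float("-inf")
            )
        attn = F.softmax(scores, dim=-1)
        # Context vectors: weighted sum over the embedding+encoder sums ES.
        context = torch.bmm(attn, ES)
        # Rescale by source length so magnitudes are independent of it.
        src_len = ES.size(1)
        context = context * (src_len * math.sqrt(1.0 / src_len))
        # Project back to the hidden dim and apply the residual connection.
        out = (self.out_projection(context) + residual) * math.sqrt(
            self.normalization_constant
        )
        return out, attn
```

A call with batched inputs would return the attended decoder states of shape (batch size, target length, hidden dim) together with the attention weights of shape (batch size, target length, source length).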