Source code for sgnlp.models.lsr.modules.reasoner

import torch.nn as nn
import torch
import torch.nn.functional as F

from .gcn import GraphConvLayer
from ..utils import get_default_device


class StructInduction(nn.Module):
    def __init__(self, sem_dim_size, sent_hiddent_size, bidirectional, device=None):
        super(StructInduction, self).__init__()
        self.device = device if device else get_default_device()
        self.bidirectional = bidirectional
        self.sem_dim_size = sem_dim_size
        self.str_dim_size = sent_hiddent_size - self.sem_dim_size

        self.tp_linear = nn.Linear(self.str_dim_size, self.str_dim_size, bias=True)
        torch.nn.init.xavier_uniform_(self.tp_linear.weight)
        nn.init.constant_(self.tp_linear.bias, 0)

        self.tc_linear = nn.Linear(self.str_dim_size, self.str_dim_size, bias=True)
        torch.nn.init.xavier_uniform_(self.tc_linear.weight)
        nn.init.constant_(self.tc_linear.bias, 0)

        self.fi_linear = nn.Linear(self.str_dim_size, 1, bias=False)
        torch.nn.init.xavier_uniform_(self.fi_linear.weight)

        self.bilinear = nn.Bilinear(self.str_dim_size, self.str_dim_size, 1, bias=False)
        torch.nn.init.xavier_uniform_(self.bilinear.weight)

        self.exparam = nn.Parameter(torch.Tensor(1, 1, self.sem_dim_size))
        torch.nn.init.xavier_uniform_(self.exparam)

        self.fzlinear = nn.Linear(3 * self.sem_dim_size, 2 * self.sem_dim_size, bias=True)
        torch.nn.init.xavier_uniform_(self.fzlinear.weight)
        nn.init.constant_(self.fzlinear.bias, 0)

    def forward(self, input_tensor):  # batch*sent, token, hidden
        batch_size, token_size, dim_size = input_tensor.size()

        """STEP 1: Calculating Attention Matrix"""
        # Split each node representation into a semantic part (sem_v) and a
        # structural part (str_v); the structural part scores candidate edges.
        if self.bidirectional:
            input_tensor = input_tensor.view(batch_size, token_size, 2, dim_size // 2)
            sem_v = torch.cat(
                (
                    input_tensor[:, :, 0, : self.sem_dim_size // 2],
                    input_tensor[:, :, 1, : self.sem_dim_size // 2],
                ),
                2,
            )
            str_v = torch.cat(
                (
                    input_tensor[:, :, 0, self.sem_dim_size // 2 :],
                    input_tensor[:, :, 1, self.sem_dim_size // 2 :],
                ),
                2,
            )
        else:
            sem_v = input_tensor[:, :, : self.sem_dim_size]
            str_v = input_tensor[:, :, self.sem_dim_size :]

        tp = torch.tanh(self.tp_linear(str_v))  # b*s, token, h1 (parent states)
        tc = torch.tanh(self.tc_linear(str_v))  # b*s, token, h1 (child states)
        tp = tp.unsqueeze(2).expand(tp.size(0), tp.size(1), tp.size(1), tp.size(2)).contiguous()
        tc = tc.unsqueeze(2).expand(tc.size(0), tc.size(1), tc.size(1), tc.size(2)).contiguous()

        f_ij = self.bilinear(tp, tc).squeeze(dim=-1)  # b*s, token, token: pairwise edge scores
        f_i = torch.exp(self.fi_linear(str_v)).squeeze(dim=-1)  # b*s, token: root scores

        # Zero the diagonal so self-edges carry no weight.
        mask = torch.ones(f_ij.size(1), f_ij.size(1)) - torch.eye(f_ij.size(1), f_ij.size(1))
        mask = mask.unsqueeze(0).expand(f_ij.size(0), mask.size(0), mask.size(1)).to(device=self.device)
        A_ij = torch.exp(f_ij) * mask

        """STEP 2: Induce Latent Structure"""
        # Build the graph Laplacian from A_ij and replace its first row with the
        # root scores (the L-bar matrix used in matrix-tree style structured attention).
        tmp = torch.sum(A_ij, dim=1)
        res = torch.zeros(batch_size, token_size, token_size).to(device=self.device)
        res.diagonal(dim1=1, dim2=2).copy_(tmp)  # assign column sums to the diagonal
        L_ij = -A_ij + res  # A_ij has 0s on its diagonal

        L_ij_bar = L_ij
        L_ij_bar[:, 0, :] = f_i

        LLinv = torch.inverse(L_ij_bar)

        # Marginals: d0 for root attachments, dx for token-to-token edges.
        d0 = f_i * LLinv[:, :, 0]

        LLinv_diag = torch.diagonal(LLinv, dim1=-2, dim2=-1).unsqueeze(2)

        tmp1 = (A_ij.transpose(1, 2) * LLinv_diag).transpose(1, 2)
        tmp2 = A_ij * LLinv.transpose(1, 2)

        temp11 = torch.zeros(batch_size, token_size, 1)
        temp21 = torch.zeros(batch_size, 1, token_size)
        temp12 = torch.ones(batch_size, token_size, token_size - 1)
        temp22 = torch.ones(batch_size, token_size - 1, token_size)

        mask1 = torch.cat([temp11, temp12], 2).to(device=self.device)
        mask2 = torch.cat([temp21, temp22], 1).to(device=self.device)

        dx = mask1 * tmp1 - mask2 * tmp2
        d = torch.cat([d0.unsqueeze(1), dx], dim=1)
        df = d.transpose(1, 2)

        # Propagate semantic vectors over the induced structure.
        ssr = torch.cat([self.exparam.repeat(batch_size, 1, 1), sem_v], 1)
        pinp = torch.bmm(df, ssr)
        cinp = torch.bmm(dx, sem_v)

        finp = torch.cat([sem_v, pinp, cinp], dim=2)
        output = F.relu(self.fzlinear(finp))

        return output, df

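# Illustrative sketch (not part of the original source): expected shapes for a
# standalone StructInduction call on CPU. The sizes are assumptions chosen so the
# bidirectional split works out (hidden 120 -> 60 semantic + 60 structural dims),
# and the device is passed explicitly so the example does not depend on what
# get_default_device() resolves to.
#
#   induction = StructInduction(
#       sem_dim_size=60, sent_hiddent_size=120, bidirectional=True,
#       device=torch.device("cpu"),
#   )
#   nodes = torch.randn(2, 5, 120)      # (batch, nodes, hidden)
#   output, df = induction(nodes)
#   output.shape                        # torch.Size([2, 5, 120]) == (batch, nodes, 2 * sem_dim_size)
#   df.shape                            # torch.Size([2, 5, 6]): root column plus 5x5 edge marginals
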
class DynamicReasoner(nn.Module):
    def __init__(self, hidden_size, gcn_layer, dropout_gcn):
        super(DynamicReasoner, self).__init__()
        self.hidden_size = hidden_size
        self.gcn_layer = gcn_layer
        self.dropout_gcn = dropout_gcn
        self.struc_att = StructInduction(hidden_size // 2, hidden_size, True)
        self.gcn = GraphConvLayer(hidden_size, self.gcn_layer, self.dropout_gcn, self_loop=True)

    def forward(self, input_tensor):
        # Structure induction
        _, att = self.struc_att(input_tensor)
        # Perform reasoning: drop the root column of the induced attention and
        # use the remaining token-to-token weights as the GCN adjacency.
        output = self.gcn(att[:, :, 1:], input_tensor)
        return output
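
# A minimal smoke test, given as a hedged sketch rather than part of the library:
# the hyperparameters (hidden size 120, 2 GCN sublayers, dropout 0.1) are
# illustrative, and GraphConvLayer (imported from .gcn) is assumed to return node
# features with the same (batch, nodes, hidden_size) shape as its input. Run with
# `python -m sgnlp.models.lsr.modules.reasoner` so the relative imports resolve.
if __name__ == "__main__":
    device = get_default_device()  # same default the inner StructInduction falls back to
    reasoner = DynamicReasoner(hidden_size=120, gcn_layer=2, dropout_gcn=0.1).to(device)
    nodes = torch.randn(2, 5, 120, device=device)  # (batch, nodes, hidden)
    refined = reasoner(nodes)  # induced structure (minus the root column) drives the GCN
    print(refined.shape)  # expected: torch.Size([2, 5, 120]), assuming GraphConvLayer keeps the hidden size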