import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import get_default_device
class GraphConvLayer(nn.Module):
    """A densely connected GCN (DCGCN) block operating on dependency graphs.

    ``layers`` sub-layers each emit ``head_dim = mem_dim // layers`` features;
    dense connections feed all earlier outputs into each later sub-layer.
    """
    def __init__(self, mem_dim, layers, dropout, device=None, self_loop=False):
        super().__init__()
        self.device = device if device else get_default_device()
        self.mem_dim = mem_dim
        self.layers = layers
        # Each sub-layer contributes head_dim features; mem_dim should be
        # divisible by layers so the concatenated outputs recover mem_dim.
        self.head_dim = self.mem_dim // self.layers
        # Accept either a dropout probability or an nn.Dropout module, so
        # `self.gcn_drop` is always callable in forward().
        self.gcn_drop = (
            dropout if isinstance(dropout, nn.Module) else nn.Dropout(dropout)
        )
        # Output projection applied after the residual connection.
        self.linear_output = nn.Linear(self.mem_dim, self.mem_dim)
        # DCGCN block: dense connections grow sub-layer i's input size to
        # mem_dim + head_dim * i, while every sub-layer emits head_dim.
        self.weight_list = nn.ModuleList()
        for i in range(self.layers):
            self.weight_list.append(
                nn.Linear(self.mem_dim + self.head_dim * i, self.head_dim)
            )
        self.weight_list = self.weight_list.to(device=self.device)
        self.linear_output = self.linear_output.to(device=self.device)
        self.self_loop = self_loop
    def forward(self, adj, gcn_inputs):
        """Run the DCGCN block.

        Args:
            adj: Batched adjacency matrices of shape
                ``(batch, num_nodes, num_nodes)``.
            gcn_inputs: Node representations of shape
                ``(batch, num_nodes, mem_dim)``.

        Returns:
            Updated node representations of shape
            ``(batch, num_nodes, mem_dim)``.
        """
        # Degree-based normalization; the +1 accounts for the self connection.
        denom = adj.sum(2).unsqueeze(2) + 1
        outputs = gcn_inputs
        cache_list = [outputs]
        output_list = []
        for layer_idx in range(self.layers):
            Ax = adj.bmm(outputs)  # aggregate features from neighboring nodes
            AxW = self.weight_list[layer_idx](Ax)
            if self.self_loop:
                # Add a transformed self-loop contribution for each node.
                AxW = AxW + self.weight_list[layer_idx](outputs)
            AxW = AxW / denom
            gAxW = F.relu(AxW)
            # Dense connection: later sub-layers see all earlier outputs.
            cache_list.append(gAxW)
            outputs = torch.cat(cache_list, dim=2)
            output_list.append(self.gcn_drop(gAxW))
        # Concatenating the `layers` outputs of size `head_dim` restores
        # `mem_dim`; add a residual connection and project.
        gcn_outputs = torch.cat(output_list, dim=2)
        gcn_outputs = gcn_outputs + gcn_inputs
        out = self.linear_output(gcn_outputs)
        return out
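
# A minimal usage sketch (illustrative, not part of the module): it builds a
# random node-feature batch and a symmetric 0/1 adjacency matrix, then runs
# one DCGCN block. All sizes below are arbitrary assumptions; `mem_dim` just
# has to be divisible by `layers`.
if __name__ == "__main__":
    batch_size, num_nodes, mem_dim = 2, 5, 96
    layer = GraphConvLayer(
        mem_dim=mem_dim, layers=4, dropout=0.1, device=torch.device("cpu")
    )
    adj = torch.randint(0, 2, (batch_size, num_nodes, num_nodes)).float()
    adj = ((adj + adj.transpose(1, 2)) > 0).float()  # symmetrize (undirected)
    x = torch.randn(batch_size, num_nodes, mem_dim)
    out = layer(adj, x)
    print(out.shape)  # torch.Size([2, 5, 96]): the node dimension is preserved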