fix softmax direction for gat/agnn
rusty1s committed Apr 19, 2019
1 parent 8be38c2 commit 6fe33a2
Showing 2 changed files with 18 additions and 18 deletions.
31 changes: 15 additions & 16 deletions torch_geometric/nn/conv/agnn_conv.py
@@ -1,11 +1,11 @@
 import torch
 from torch.nn import Parameter
 import torch.nn.functional as F
-from torch_sparse import spmm
+from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
 
 
-class AGNNConv(torch.nn.Module):
+class AGNNConv(MessagePassing):
     r"""Graph attentional propagation layer from the
     `"Attention-based Graph Neural Network for Semi-Supervised Learning"
     <https://arxiv.org/abs/1803.03735>`_ paper
@@ -28,7 +28,7 @@ class AGNNConv(torch.nn.Module):
     """
 
     def __init__(self, requires_grad=True):
-        super(AGNNConv, self).__init__()
+        super(AGNNConv, self).__init__('add')
 
         self.requires_grad = requires_grad
 
@@ -43,25 +43,24 @@ def reset_parameters(self):
         if self.requires_grad:
             self.beta.data.fill_(1)
 
-    def propagation_matrix(self, x, edge_index):
+    def forward(self, x, edge_index):
+        """"""
         edge_index, _ = remove_self_loops(edge_index)
         edge_index = add_self_loops(edge_index, x.size(0))
-        row, col = edge_index
 
-        beta = self.beta if self.requires_grad else self._buffers['beta']
+        x_norm = F.normalize(x, p=2, dim=-1)
 
-        x = x.unsqueeze(-1) if x.dim() == 1 else x
-        x = F.normalize(x, p=2, dim=-1)
-        edge_weight = beta * (x[row] * x[col]).sum(dim=-1)
-        edge_weight = softmax(edge_weight, row, num_nodes=x.size(0))
+        return self.propagate(
+            edge_index, x=x, x_norm=x_norm, num_nodes=x.size(0))
 
-        return edge_index, edge_weight
+    def message(self, edge_index, x_j, x_norm_i, x_norm_j, num_nodes):
+        # Compute attention coefficients.
+        beta = self.beta if self.requires_grad else self._buffers['beta']
+        alpha = beta * (x_norm_i * x_norm_j).sum(dim=-1)
+        i = 0 if self.flow == 'target_to_source' else 1
+        alpha = softmax(alpha, edge_index[i], num_nodes)
 
-    def forward(self, x, edge_index):
-        """"""
-        edge_index, edge_weight = self.propagation_matrix(x, edge_index)
-        out = spmm(edge_index, edge_weight, x.size(0), x)
-        return out
+        return x_j * alpha.view(-1, 1)
 
     def __repr__(self):
         return '{}()'.format(self.__class__.__name__)
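
For context on what the changed softmax index does: softmax(alpha, edge_index[i], num_nodes) normalizes the attention coefficients within groups of edges that share the node stored at edge_index[i], i.e. the node that aggregates the incoming messages. The sketch below is not part of the commit; grouped_softmax is a hypothetical helper that re-implements that grouped softmax in plain PyTorch on an invented graph, to show why the grouping index has to follow self.flow.

# Minimal sketch (not from the commit): what softmax(alpha, edge_index[i], num_nodes) computes.
import torch

def grouped_softmax(src, index, num_nodes):
    # Softmax of `src` within each group defined by `index` (one group per node).
    out = (src - src.max()).exp()                               # shift for numerical stability
    denom = out.new_zeros(num_nodes).scatter_add(0, index, out)  # per-node sum of exp values
    return out / denom[index].clamp(min=1e-16)

# Toy graph with three edges 0->2, 1->2, 2->0 (values are arbitrary).
edge_index = torch.tensor([[0, 1, 2],    # source node of each edge
                           [2, 2, 0]])   # target node of each edge
alpha = torch.tensor([0.5, 1.5, 1.0])

i = 1  # flow == 'source_to_target': group by the aggregating (target) node
w = grouped_softmax(alpha, edge_index[i], num_nodes=3)
# w[0] + w[1] == 1: the two edges entering node 2 compete with each other,
# which is the normalization the AGNN attention expects.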
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/gat_conv.py
@@ -87,11 +87,12 @@ def forward(self, x, edge_index):
         x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
         return self.propagate(edge_index, x=x, num_nodes=x.size(0))
 
-    def message(self, x_i, x_j, edge_index, num_nodes):
+    def message(self, edge_index, x_i, x_j, num_nodes):
         # Compute attention coefficients.
         alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
         alpha = F.leaky_relu(alpha, self.negative_slope)
-        alpha = softmax(alpha, edge_index[0], num_nodes)
+        i = 0 if self.flow == 'target_to_source' else 1
+        alpha = softmax(alpha, edge_index[i], num_nodes)
 
         # Sample attention coefficients stochastically.
         if self.training and self.dropout > 0:
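
The same contrast for GATConv, sketched with torch_geometric.utils.softmax called with the same positional arguments as the patched code; the graph and attention values below are invented for illustration. Grouping by edge_index[0] normalizes over each source node's outgoing edges (the old behaviour), while grouping by edge_index[1] normalizes over each target node's incoming edges, which is the intended direction under the default flow='source_to_target'.

# Hedged sketch: old vs. new grouping index on an invented graph.
import torch
from torch_geometric.utils import softmax

edge_index = torch.tensor([[0, 1, 2, 2],    # sources
                           [2, 2, 0, 1]])   # targets
alpha = torch.tensor([0.5, 1.5, 1.0, 2.0])

old = softmax(alpha, edge_index[0], num_nodes=3)  # per source node (before the fix)
new = softmax(alpha, edge_index[1], num_nodes=3)  # per target node (after the fix)
# `new` sums to 1 over the edges entering each node, e.g. new[0] + new[1] == 1
# for the two edges that end at node 2.
print(old, new)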
