Skip to content

Commit

Permalink
Support edge_attr & self_loops for SpTensor in GAT
Browse files Browse the repository at this point in the history
  • Loading branch information
dongkwan-kim committed Sep 4, 2021
1 parent 6c6747b commit d3213b5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 21 deletions.
49 changes: 34 additions & 15 deletions test/nn/conv/test_gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,33 +81,52 @@ def test_gat_conv():
assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)

# Test handling edge weight or multi-dimensional features.
edge_weight = torch.ones((edge_index.size(1)), dtype=torch.float)
edge_attr = torch.ones((edge_index.size(1), 7), dtype=torch.float)

def test_gat_conv_with_edge_attr():
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]])
row, col = edge_index

edge_weight = torch.randn(row.size(0))
adj_1 = SparseTensor(row=row, col=col, value=edge_weight,
sparse_sizes=(4, 4))
conv = GATConv(8, 32, heads=2, edge_dim=1, edge_attr_for_self_loops='fill',
edge_attr_fill_value=0.5)
out = conv(x1, edge_index, edge_attr=edge_weight)
out = conv(x, edge_index, edge_attr=edge_weight)
assert out.size() == (4, 64)
assert conv(x1, edge_index, size=(4, 4),
assert conv(x, edge_index, size=(4, 4),
edge_attr=edge_weight).tolist() == out.tolist()
assert torch.allclose(conv(x, adj_1.t()), out, atol=1e-6)
assert torch.allclose(conv(x, adj_1.t(), edge_attr=edge_weight), out,
atol=1e-6)

conv = GATConv(8, 32, heads=2, edge_dim=7, edge_attr_for_self_loops='fill',
edge_attr_fill_value=0.5)
out = conv(x1, edge_index, edge_attr=edge_attr)
edge_attr = torch.randn((row.size(0), 7), dtype=torch.float)
adj_2 = SparseTensor(row=row, col=col, value=edge_attr,
sparse_sizes=(4, 4))
conv = GATConv(8, 32, heads=2, edge_dim=7, edge_attr_for_self_loops='mean')
out = conv(x, edge_index, edge_attr=edge_attr)
assert out.size() == (4, 64)
assert conv(x1, edge_index, size=(4, 4),
assert conv(x, edge_index, size=(4, 4),
edge_attr=edge_attr).tolist() == out.tolist()

edge_attr = torch.ones((edge_index.size(1), 7), dtype=torch.float)
conv = GATConv(8, 32, heads=2, edge_dim=7, edge_attr_for_self_loops='mean')
out = conv(x1, edge_index, edge_attr=edge_attr)
conv = GATConv(8, 32, heads=2, edge_dim=7, edge_attr_for_self_loops='fill',
edge_attr_fill_value=0.5)
out = conv(x, edge_index, edge_attr=edge_attr)
assert out.size() == (4, 64)
assert conv(x1, edge_index, size=(4, 4),
assert conv(x, edge_index, size=(4, 4),
edge_attr=edge_attr).tolist() == out.tolist()
assert torch.allclose(conv(x, adj_2.t()), out, atol=1e-6)
assert torch.allclose(conv(x, adj_2.t(), edge_attr=edge_attr), out,
atol=1e-6)

t = '(Tensor, Tensor, Size, OptTensor, NoneType) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, edge_index, edge_attr=edge_attr).tolist() == out.tolist()
assert jit(x1, edge_index, size=(4, 4),
assert jit(x, edge_index, edge_attr=edge_attr).tolist() == out.tolist()
assert jit(x, edge_index, size=(4, 4),
edge_attr=edge_attr).tolist() == out.tolist()

t = '(Tensor, SparseTensor, Size, OptTensor, NoneType) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x, adj_2.t()), out, atol=1e-6)
assert torch.allclose(jit(x, adj_2.t(), edge_attr=edge_attr), out,
atol=1e-6)
22 changes: 16 additions & 6 deletions torch_geometric/nn/conv/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag
from torch_sparse import SparseTensor, set_diag, fill_diag
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
Expand Down Expand Up @@ -206,15 +206,25 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
fill_value=self.edge_attr_fill_value,
num_nodes=num_nodes)
elif isinstance(edge_index, SparseTensor):
assert edge_attr is None, \
("Using `edge_attr` not supported for "
"`edge_index` in SparseTensor form.")
edge_index = set_diag(edge_index)
if self.edge_dim is None:
edge_index = set_diag(edge_index)
else:
assert self.edge_attr_for_self_loops == 'fill', \
('Using `edge_attr` and `add_self_loops` '
'simultaneously with "{}" method is not '
'supported for `edge_index` in a SparseTensor '
'form.'.format(self.edge_attr_for_self_loops))
if edge_attr is not None:
edge_index.set_value(edge_attr, layout='coo')
edge_index = fill_diag(
edge_index, fill_value=self.edge_attr_fill_value)
_, _, edge_attr = edge_index.coo()

# If edge features are given, compute attention using them.
if edge_attr is not None:
assert self.lin_edge is not None
edge_attr = edge_attr.view(edge_index.size(1), -1)
if edge_attr.dim() == 1:
edge_attr = edge_attr.unsqueeze(-1)
edge_attr = self.lin_edge(edge_attr).view(-1, H, C)
alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
else:
Expand Down

0 comments on commit d3213b5

Please sign in to comment.