Skip to content

Commit

Permalink
Support edge_attr in GAT when add_self_loops=False
Browse files Browse the repository at this point in the history
  • Loading branch information
dongkwan-kim committed Sep 4, 2021
1 parent d3213b5 commit 7f5cd3f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
18 changes: 18 additions & 0 deletions test/nn/conv/test_gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ def test_gat_conv_with_edge_attr():
assert torch.allclose(conv(x, adj_1.t(), edge_attr=edge_weight), out,
atol=1e-6)

conv = GATConv(8, 32, heads=2, edge_dim=1, add_self_loops=False)
out = conv(x, edge_index, edge_attr=edge_weight)
assert out.size() == (4, 64)
assert conv(x, edge_index, size=(4, 4),
edge_attr=edge_weight).tolist() == out.tolist()
assert torch.allclose(conv(x, adj_1.t()), out, atol=1e-6)
assert torch.allclose(conv(x, adj_1.t(), edge_attr=edge_weight), out,
atol=1e-6)

edge_attr = torch.randn((row.size(0), 7), dtype=torch.float)
adj_2 = SparseTensor(row=row, col=col, value=edge_attr,
sparse_sizes=(4, 4))
Expand All @@ -119,6 +128,15 @@ def test_gat_conv_with_edge_attr():
assert torch.allclose(conv(x, adj_2.t(), edge_attr=edge_attr), out,
atol=1e-6)

conv = GATConv(8, 32, heads=2, edge_dim=7, add_self_loops=False)
out = conv(x, edge_index, edge_attr=edge_attr)
assert out.size() == (4, 64)
assert conv(x, edge_index, size=(4, 4),
edge_attr=edge_attr).tolist() == out.tolist()
assert torch.allclose(conv(x, adj_2.t()), out, atol=1e-6)
assert torch.allclose(conv(x, adj_2.t(), edge_attr=edge_attr), out,
atol=1e-6)

t = '(Tensor, Tensor, Size, OptTensor, NoneType) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x, edge_index, edge_attr=edge_attr).tolist() == out.tolist()
Expand Down
5 changes: 5 additions & 0 deletions torch_geometric/nn/conv/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
edge_index = fill_diag(
edge_index, fill_value=self.edge_attr_fill_value)
_, _, edge_attr = edge_index.coo()
else:
if isinstance(edge_index, SparseTensor):
if edge_attr is not None:
edge_index.set_value(edge_attr, layout='coo')
_, _, edge_attr = edge_index.coo()

# If edge features are given, compute attention using them.
if edge_attr is not None:
Expand Down

0 comments on commit 7f5cd3f

Please sign in to comment.