Skip to content

Commit

Permalink
Handle empty edge_index in add_self_loops
Browse files Browse the repository at this point in the history
  • Loading branch information
dongkwan-kim committed Sep 4, 2021
1 parent 7f5cd3f commit 7444821
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
16 changes: 16 additions & 0 deletions test/utils/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ def test_add_self_loops():
assert out.tolist() == expected
assert edge_attr[3:, :].tolist() == [[1, 0, 1], [0, 1, 0]]

empty_edge_index = torch.tensor([]).view(2, 0)
out, _ = add_self_loops(empty_edge_index, num_nodes=1)
assert out.tolist() == [[0], [0]]

empty_edge_weight = torch.tensor([])
out, empty_edge_weight = add_self_loops(
empty_edge_index, empty_edge_weight, num_nodes=1)
assert out.tolist() == [[0], [0]]
assert empty_edge_weight.tolist() == [1]

empty_edge_attr = torch.tensor([]).view(0, 3)
out, empty_edge_attr = add_self_loops(
empty_edge_index, empty_edge_attr, num_nodes=1)
assert out.tolist() == [[0], [0]]
assert empty_edge_attr.tolist() == [[1, 1, 1]]


def test_add_remaining_self_loops():
edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/utils/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ def add_self_loops(edge_index, edge_attr: Optional[torch.Tensor] = None,
loop_attr = scatter(edge_attr, edge_index[0], dim=0, dim_size=N,
reduce=fill_or_reduce)
else:
num_features = edge_attr.numel() // E
num_features = 1 if edge_attr.dim() == 1 else edge_attr.size(-1)
loop_attr = edge_attr.new_full(
(N, num_features), fill_value).squeeze()
(N, num_features), fill_value).squeeze(-1)
edge_attr = torch.cat([edge_attr, loop_attr], dim=0)

edge_index = torch.cat([edge_index, loop_index], dim=1)
Expand Down

0 comments on commit 7444821

Please sign in to comment.