Skip to content

Commit

Permalink
GATConv: require edge_dim to be set
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 20, 2022
1 parent 5fdeae5 commit 025b1cb
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions torch_geometric/nn/conv/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,9 @@ def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
# we simply need to sum them up to "emulate" concatenation:
alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

if edge_attr is not None:
if edge_attr is not None and self.lin_edge is not None:
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
assert self.lin_edge is not None
edge_attr = self.lin_edge(edge_attr)
edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
Expand Down

0 comments on commit 025b1cb

Please sign in to comment.