From 025b1cb0c94eeac768d6facbb942d84c223c0b19 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 20 Apr 2022 11:10:31 +0200 Subject: [PATCH] GATConv: require edge_dim to be set --- torch_geometric/nn/conv/gat_conv.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 13b1af546b18..b5d0e18194da 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -268,10 +268,9 @@ def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor, # we simply need to sum them up to "emulate" concatenation: alpha = alpha_j if alpha_i is None else alpha_j + alpha_i - if edge_attr is not None: + if edge_attr is not None and self.lin_edge is not None: if edge_attr.dim() == 1: edge_attr = edge_attr.view(-1, 1) - assert self.lin_edge is not None edge_attr = self.lin_edge(edge_attr) edge_attr = edge_attr.view(-1, self.heads, self.out_channels) alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)