
GATConv saves weights compatibility? #1755

Open
jby1993 opened this issue Oct 24, 2020 · 12 comments

Comments

@jby1993

jby1993 commented Oct 24, 2020

❓ Questions & Help

I used pytorch_geometric 1.3.2 to train a model that uses the GATConv module. I recently upgraded to the latest PyG and found that GATConv has changed, so the old weights file can't be loaded. I converted the weights file, but I can't reproduce the results of the old version. Can somebody tell me how to modify the weights file to work with the latest PyG?

@rusty1s
Member

rusty1s commented Oct 25, 2020

You basically need to ensure that:

conv.lin_l.weight == conv_old.weight
conv.att_l.weight == conv_old.att[:, :, :out_channels]
conv.att_r.weight == conv_old.att[:, :, out_channels:]

@jby1993
Author

jby1993 commented Oct 27, 2020

OK. This is my conversion:
conv.lin_l.weight = conv_old.weight.t()
conv.lin_r.weight = conv_old.weight.t()
conv.att_l.weight = conv_old.att[:, :, :out_channels]
conv.att_r.weight = conv_old.att[:, :, out_channels:]
However, the outputs are not fully consistent. I also tried converting to version 1.5.0 with a similar method, and there the outputs are identical. What's the problem?

@rusty1s
Member

rusty1s commented Oct 27, 2020

I checked that. You also need to transpose the weight matrix:

conv2.lin_l.weight.data = conv1.weight.data.t()
conv2.att_l.data = conv1.att[:, :, out_channels:].data
conv2.att_r.data = conv1.att[:, :, :out_channels].data

That yields an equal result for me.
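The mapping above can be checked without loading any models. Here is a minimal, framework-free NumPy sketch. It assumes the 1.3.x layout: `weight` of shape `(in_channels, heads * out_channels)` applied as `x @ weight`, and `att` of shape `(1, heads, 2 * out_channels)` multiplied with `cat([x_i, x_j])`, as in the old `message()` quoted later in this thread. It also assumes that the newer layer adds a per-node score computed with `att_l` for the source node j and one computed with `att_r` for the target node i, which is consistent with the slicing in the comment above. `convert_gat_params`, `old_logits` and `new_logits` are hypothetical helpers, not PyG APIs, and they compare only the raw attention logits (before LeakyReLU and softmax).

```python
import numpy as np


def convert_gat_params(old, out_channels):
    """Hypothetical helper: map 1.3.x-style GATConv parameters to 1.6.x-style names."""
    weight, att = old["weight"], old["att"]
    new = {
        # A Linear layer stores its weight as (out_features, in_features), so transpose.
        "lin_l.weight": weight.T.copy(),
        # Assumption: att_l scores the source node x_j, att_r the target node x_i.
        "att_l": att[:, :, out_channels:].copy(),
        "att_r": att[:, :, :out_channels].copy(),
    }
    # Tie lin_r to lin_l, as done in this thread for 1.6.1.
    new["lin_r.weight"] = new["lin_l.weight"].copy()
    if "bias" in old:
        new["bias"] = old["bias"].copy()
    return new


def old_logits(x, edge_index, p, heads, out_channels):
    """Raw logits as in the old message(): (cat([x_i, x_j]) * att).sum(-1)."""
    h = (x @ p["weight"]).reshape(-1, heads, out_channels)
    src, dst = edge_index  # row 0 = source j, row 1 = target i
    x_j, x_i = h[src], h[dst]
    return (np.concatenate([x_i, x_j], axis=-1) * p["att"]).sum(axis=-1)


def new_logits(x, edge_index, p, heads, out_channels):
    """Raw logits built from separate per-node scores alpha_l and alpha_r."""
    h_l = (x @ p["lin_l.weight"].T).reshape(-1, heads, out_channels)
    h_r = (x @ p["lin_r.weight"].T).reshape(-1, heads, out_channels)
    alpha_l = (h_l * p["att_l"]).sum(axis=-1)  # one score per node and head
    alpha_r = (h_r * p["att_r"]).sum(axis=-1)
    src, dst = edge_index
    return alpha_l[src] + alpha_r[dst]


rng = np.random.default_rng(0)
num_nodes, in_channels, heads, out_channels = 5, 4, 2, 3
old = {
    "weight": rng.normal(size=(in_channels, heads * out_channels)),
    "att": rng.normal(size=(1, heads, 2 * out_channels)),
}
x = rng.normal(size=(num_nodes, in_channels))
edge_index = np.array([[0, 1, 2, 3, 4, 0],
                       [1, 2, 3, 4, 0, 2]])
new = convert_gat_params(old, out_channels)
print(np.allclose(old_logits(x, edge_index, old, heads, out_channels),
                  new_logits(x, edge_index, new, heads, out_channels)))  # True
```

The check passes because the concatenated dot product splits into one term per half of `att`; if the two halves are swapped, the logits generally no longer match.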

@Zhiwei-Zhai

Hi, I have the same problem. Did you find an easy solution for it?
I updated torch-geometric to the latest version, but my model was trained with an earlier version.

@jby1993
Author

jby1993 commented Oct 28, 2020

@chushan89, I used the conversion method provided by @rusty1s, and it works for version 1.5.0. When I use the same method for 1.6.1, the converted model's outputs deviate slightly. I suspect the lin_r transform introduced in 1.6.1 has some influence. For now I just set its weight to be the same as lin_l.

@dharouni

Hello @rusty1s,
I have a question about the GATConv implementation, regarding this part of the forward() function:

if isinstance(x, Tensor):
    assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
    x_l = x_r = self.lin_l(x).view(-1, H, C)
    alpha_l = alpha_r = (x_l * self.att_l).sum(dim=-1)

In this implementation, self.att_r seems to be unused: both parts of the concatenation are multiplied by the same vector self.att_l, of dimension (1, out_channels). In the paper, however, the vector has dimension (1, 2 * out_channels), so each part of the concatenation is multiplied by its own learnable parameters. Could you explain whether the implementation indeed differs from the paper?

Thank you very much!
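For reference, the attention logit in the GAT paper splits into two independent terms. Writing the paper's attention vector as the stacked pair of two halves a_1 and a_2, each with as many entries as out_channels:

```math
\begin{aligned}
e_{ij} &= \mathrm{LeakyReLU}\left(\mathbf{a}^\top \left[\mathbf{W}h_i \,\Vert\, \mathbf{W}h_j\right]\right),
\qquad \mathbf{a} = \begin{bmatrix}\mathbf{a}_1 \\ \mathbf{a}_2\end{bmatrix} \\
&= \mathrm{LeakyReLU}\left(\mathbf{a}_1^\top \mathbf{W}h_i + \mathbf{a}_2^\top \mathbf{W}h_j\right)
\end{aligned}
```

The quoted snippet uses self.att_l for both terms, which amounts to constraining a_1 = a_2; the paper learns the two halves separately.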

@rusty1s
Member

rusty1s commented Nov 12, 2020

Yeah, that is already fixed in master, see https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/conv/gat_conv.py#L126-L127. Sorry for the inconvenience!

@eMarco

eMarco commented Feb 24, 2021

Hi, I have the same issue.
I trained my GATConv models with version 1.1.2, but I can't reproduce the same results even with version 1.2.0.
Any hints on how to use these weight files with PyGeom 1.6.x (or even 1.3.x, from which I could just map the weights as suggested above)?
Unfortunately, forward-porting the old class doesn't seem viable due to the many changes in the MessagePassing class (e.g., TorchScript support, the interface, etc.).

@rusty1s
Member

rusty1s commented Feb 24, 2021

Yes, the above solution should work for the current GATConv implementation. Please let me know if this works for you.

@eMarco

eMarco commented Feb 24, 2021

Thank you for your reply and for your work!

Unfortunately it is not working with v1.6.3.
I get values similar to those from v1.2.0, but they are not even close to the output of v1.1.2.

So far I've been trying to track down the changes that may have broken compatibility from v1.2.0 onwards, and the only relevant change I see when diffing v1.1.2 and v1.2.0 is in the computation of the alpha coefficients:

-    def message(self, x_i, x_j, edge_index, num_nodes):
+    def message(self, edge_index_i, x_i, x_j, num_nodes):
         # Compute attention coefficients.
         alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
         alpha = F.leaky_relu(alpha, self.negative_slope)
-        alpha = softmax(alpha, edge_index[0], num_nodes)
+        alpha = softmax(alpha, edge_index_i, num_nodes)

And yes, forward-porting the v1.1.2 GATConv to v1.2.0 works as expected, so I'm fairly sure this is the culprit.

Any other ideas on how to adapt the old weights?
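The two softmax calls in the diff above normalize over different groups of edges. Assuming PyG's default flow="source_to_target", edge_index[0] holds the source nodes j and edge_index[1] the target nodes i, so edge_index_i is edge_index[1]. Under that assumption, v1.1.2 normalized each score over the outgoing edges of its source node, while v1.2.0+ normalizes over the incoming edges of its target node. The NumPy sketch below uses a hypothetical segment_softmax helper (an illustration, not torch_geometric.utils.softmax itself) to show the difference on a 3-node toy graph.

```python
import numpy as np


def segment_softmax(alpha, index, num_nodes):
    """Softmax of alpha over all entries that share the same value in index."""
    alpha = np.asarray(alpha, dtype=float)
    # Subtract each group's maximum for numerical stability.
    group_max = np.full(num_nodes, -np.inf)
    np.maximum.at(group_max, index, alpha)
    exp = np.exp(alpha - group_max[index])
    group_sum = np.zeros(num_nodes)
    np.add.at(group_sum, index, exp)
    return exp / group_sum[index]


edge_index = np.array([[0, 0, 1, 2],   # source nodes j
                       [1, 2, 2, 0]])  # target nodes i
alpha = np.array([1.0, 2.0, 0.5, -1.0])  # one raw score per edge

by_target = segment_softmax(alpha, edge_index[1], 3)  # v1.2.0+: index = edge_index_i
by_source = segment_softmax(alpha, edge_index[0], 3)  # v1.1.2: edge_index[0]

# Total attention arriving at each target node.
print(np.bincount(edge_index[1], weights=by_target, minlength=3))  # every entry is 1
print(np.bincount(edge_index[1], weights=by_source, minlength=3))  # approx. [1.0, 0.269, 1.731]
```

With the old grouping, the weights a target node receives need not sum to 1, so its output is no longer a weighted average of its neighbors. That would likely explain why remapping att_l and att_r alone cannot reproduce the v1.1.2 outputs.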

@rusty1s
Member

rusty1s commented Feb 25, 2021

You may try to swap the values of att_l and att_r.

@eMarco

eMarco commented Feb 26, 2021

Thank you again.

Unfortunately that didn't work either.

I ended up modifying the v1.6.3 GATConv code by "reverting" the change I posted in my previous comment (i.e., I'm now passing edge_index[0] instead of index to the softmax), and the results are (almost) identical to v1.1.2.
The same is true if I pass edge_index_j instead. I'll dig a bit further into the MessagePassing code to understand what's going on, and check whether everything is in order for a few more networks.
