How to obtain gradients for input node features and graph structure? #3181
@monk1337 It is not necessary to pass only …

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv

    # `dataset` is assumed to be defined beforehand (a PyG dataset with node features and class labels)
    class GCN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(dataset.num_node_features, 16)
            self.conv2 = GCNConv(16, dataset.num_classes)

        def forward(self, data):
            # You can create a custom dataset object and then use it here.
            x, edge_index = data.x, data.edge_index
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)
            x = self.conv2(x, edge_index)
            return F.log_softmax(x, dim=1)

As an example, this is how we can create a custom dataset:

    import networkx as nx
    import numpy as np
    import torch
    import torch_geometric as pyg

    class MyKarateClub(pyg.data.InMemoryDataset):
        def __init__(self, dim=128, transform=None):
            # dim: The dimension of node embeddings
            super().__init__(None, transform)  # root=None: nothing is downloaded or stored on disk
            G = nx.karate_club_graph().to_undirected()
            # Unweighted (0/1) dense adjacency matrix
            adj = torch.from_numpy(nx.to_numpy_array(G, weight=None)).to(torch.float)
            coo = nx.to_scipy_sparse_array(G).tocoo()  # sparse COO view to read off row/col indices
            row = torch.from_numpy(coo.row.astype(np.int64))
            col = torch.from_numpy(coo.col.astype(np.int64))
            edge_index = torch.stack([row, col], dim=0)  # edge index, shape [2, num_edges]
            # Initialize node embeddings with random values
            x = torch.rand((G.number_of_nodes(), dim), dtype=torch.float)
            data = pyg.data.Data(x=x, edge_index=edge_index, adj=adj)
            self.data, self.slices = self.collate([data])

        def __repr__(self):
            return '{}()'.format(self.__class__.__name__)

The above code is an adaptation of the source code from the PyG repository. You can now do the following:

    dataset = MyKarateClub()
    data = dataset[0]  # as there is only a single graph
    out = model(data)  # here model is your GNN whose `forward` method takes a torch_geometric.data.Data object as input

This will automatically set …
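If you want to obtain gradients for input node features and graph structure, you have to do this via an additional …

    import torch
    from torch_geometric.nn.conv import GCNConv

    # Hypothetical placeholders: A is a dense [N, N] adjacency matrix, x an [N, 16] feature matrix
    A, x = (torch.rand(4, 4) > 0.5).float(), torch.randn(4, 16)
    conv = GCNConv(16, 32)
    x = torch.nn.Parameter(x)
    edge_index = A.nonzero(as_tuple=False).t()
    edge_weight = torch.nn.Parameter(A[edge_index[0], edge_index[1]])
    out = conv(x, edge_index, edge_weight)
    out.sum().backward()  # x.grad and edge_weight.grad now hold the gradients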
@rusty1s What is …
For example
How can I obtain gradients for input node features and graph structure in this SparseTensor case?
You can wrap …

    import torch
    from torch_sparse import SparseTensor
    from torch_geometric.nn import GCN

    adj = torch.rand([2, 2])
    feat = torch.randn(2, 5)
    feat = torch.nn.Parameter(feat)
    edge_index = adj.nonzero(as_tuple=False).t()
    edge_weight = torch.nn.Parameter(adj[edge_index[0], edge_index[1]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight,
                       sparse_sizes=(2, 2))

    model = GCN(5, 16, num_layers=2)
    out = model(feat, adj.t())
    out.mean().backward()

    print(feat.grad)
    print(edge_weight.grad)
Hello, my GCN layer looks like this: …

For example, if the sample data is …

The inputs to the GCN layer are wrapped in PyTorch Parameters with `requires_grad=True`, such as …

However, in PyG networks the input is either `(x, edge_index)` or `(x, sparse_tensor.t())`.

How can I pass the input as a Parameter like this to a PyG network with `requires_grad=True`?