How to obtain gradients for input node features and graph structure? #3181

Closed
monk1337 opened this issue Sep 20, 2021 · 7 comments


monk1337 commented Sep 20, 2021

Hello, my GCN layer looks like this:

import torch
import numpy as np
import math

class GraphConvolution(torch.nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=False):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(1, 1, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.matmul(input, self.weight)
        output  = torch.matmul(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

For example, if the sample data is:

feature_matrix = np.random.uniform(-1, 1, [5, 10])
adj            = np.random.randint(0, 2, [5, 5])

The inputs to the GCN layer are wrapped in PyTorch Parameters with requires_grad=True, such as:

# where feature_matrix is a [nodes x features] numpy array
Graph_features = torch.nn.Parameter(torch.from_numpy(feature_matrix).float(), requires_grad=True)

# where adj is a [nodes x nodes] adjacency numpy array
A   = torch.nn.Parameter(torch.from_numpy(adj).float(), requires_grad=True)
gc1 = GraphConvolution(in_features=10, out_features=10)
x   = gc1(Graph_features, A)

In PyG networks, however, the input is either (x, edge_index) or (x, sparse_tensor.t()). How can I pass the input as a Parameter like this to a PyG network with requires_grad=True?

# where feature_matrix is a [nodes x features] numpy array
Graph_features = torch.nn.Parameter(torch.from_numpy(feature_matrix).float(), requires_grad=True)

# where adj is a [nodes x nodes] adjacency numpy array
A = torch.nn.Parameter(torch.from_numpy(adj).float(), requires_grad=True)
@ipsitmantri

@monk1337 It is not necessary to pass only (x, edge_index) or (x, sparse_tensor.t()). You can pass a torch_geometric.data.Data object to the forward method and then access whatever tensor/matrix you want from the object. Example code is given in the documentation:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index  # you can create a custom dataset object and then use it here.

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)
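
Since the whole Data object is passed to forward, any extra attribute stored on it can be read in the same way. A minimal, hypothetical sketch of how the forward method of the class above could also pull a dense adj attribute (as stored by the custom dataset below):

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        adj = data.adj  # any custom attribute stored on the Data object is available here

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)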

As an example, this is how we can create a custom KarateClub graph:

import networkx as nx
import numpy as np
import torch
import torch_geometric as pyg

class MyKarateClub(pyg.data.InMemoryDataset):
    def __init__(self, dim=128, transform=None):
        # dim: the dimension of the node embeddings
        super().__init__('.', transform)
        G = nx.karate_club_graph().to_undirected()
        adj = nx.to_scipy_sparse_matrix(G).tocoo()  # adjacency matrix in COO format
        row = torch.from_numpy(adj.row.astype(np.int64)).to(torch.long)
        col = torch.from_numpy(adj.col.astype(np.int64)).to(torch.long)
        edge_index = torch.stack([row, col], dim=0)  # edge index
        x = torch.rand((G.number_of_nodes(), dim), dtype=torch.float)  # initializing node embeddings with random values

        data = pyg.data.Data(x=x, edge_index=edge_index,
                             adj=torch.from_numpy(adj.toarray()).float())
        self.data, self.slices = self.collate([data])

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

The above code is an adaptation of the source code in the PyG repository. You can now do the following:

dataset = MyKarateClub()
data = dataset[0] # as there is only a single graph
out = model(data) # here model is your GNN, whose `forward` method takes a torch_geometric.data.Data object as input

This will automatically set requires_grad=True for your adjacency matrix. Hope this helps!


rusty1s commented Sep 21, 2021

If you want to obtain gradients for input node features and graph structure, you have to do this via an additional edge_weight vector. Note that you cannot simply obtain gradients from edge_index alone, as it is not a floating-point tensor.

import torch
from torch_geometric.nn.conv import GCNConv

conv = GCNConv(16, 32)

x = torch.nn.Parameter(x)

edge_index = A.nonzero(as_tuple=False).t()
edge_weight = torch.nn.Parameter(A[edge_index[0], edge_index[1]])

out = conv(x, edge_index, edge_weight)
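
The gradients can then be read out after backpropagating any scalar loss (here just the sum of the output, purely for illustration):

out.sum().backward()

print(x.grad)            # gradient w.r.t. the input node features
print(edge_weight.grad)  # gradient w.r.t. the graph structure (the non-zero entries of A)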

@monk1337

@rusty1s What is A here? Is it A = torch.nn.Parameter(torch.from_numpy(adj).float(), requires_grad=True)?

monk1337 changed the title from "How to input adj and features as pytorch Parameter to PyG?" to "How to obtain gradients for input node features and graph structure?" on Sep 21, 2021

rusty1s commented Sep 21, 2021

A would just be your dense adjacency matrix, e.g., A = torch.randint(0, 2, [5, 5]).float()
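
In other words, the dense matrix itself does not need to be a Parameter; only the edge weights extracted from it do. A small sketch reusing the numpy adjacency from the original question (the variable names are only illustrative):

import numpy as np
import torch

adj = np.random.randint(0, 2, [5, 5])
A = torch.from_numpy(adj).float()  # plain dense adjacency, no Parameter needed

edge_index = A.nonzero(as_tuple=False).t()
edge_weight = torch.nn.Parameter(A[edge_index[0], edge_index[1]])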


monk1337 commented Sep 21, 2021

@rusty1s Thank you for the quick reply. As we discussed here, I am using this GCN network in a deterministic-operations setting, where I am using SparseTensor.

@monk1337

For example:

import torch
import numpy as np
from torch_sparse import SparseTensor

adj  = torch.Tensor(np.random.randint(0, 2, [2, 2]))
feat = torch.Tensor(np.random.uniform(-1, 1, [2, 5]))

edge_index = adj.nonzero(as_tuple=False).t()
adj        = SparseTensor(row=edge_index[0], col=edge_index[1],
                          sparse_sizes=(2, 2))

model = GCN()
out   = model(feat, adj.t())

How can I obtain gradients for input node features and graph structure in this SparseTensor case?


rusty1s commented Sep 22, 2021

You can wrap feat and the non-zero values of adj into Parameters:

import torch
from torch_sparse import SparseTensor
from torch_geometric.nn import GCN

adj = torch.rand([2, 2])
feat = torch.randn(2, 5)
feat = torch.nn.Parameter(feat)

edge_index = adj.nonzero(as_tuple=False).t()
edge_weight = torch.nn.Parameter(adj[edge_index[0], edge_index[1]])
adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight,
                   sparse_sizes=(2, 2))

model = GCN(5, 16, num_layers=2)
out = model(feat, adj.t())
out.mean().backward()

print(feat.grad)         # gradients w.r.t. the input node features
print(edge_weight.grad)  # gradients w.r.t. the graph structure (edge weights)
