Skip to content

Commit

Permalink
Fix lpformer tests and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
HarryShomer committed Jan 16, 2025
1 parent c2bc98f commit d992ff2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
2 changes: 2 additions & 0 deletions test/nn/models/test_lpformer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import torch

import torch_geometric.typing
from torch_geometric.testing import withPackage
from torch_geometric.nn import LPFormer
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_undirected


@withPackage('numba') # For ppr calculation
def test_lpformer():
model = LPFormer(16, 32, num_gnn_layers=2,
num_transformer_layers=1)
Expand Down
21 changes: 10 additions & 11 deletions torch_geometric/nn/models/lpformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def forward(
edge_index: Adj,
ppr_matrix: Tensor,
) -> Tensor:
r"""Forward Pass of LPFormer
r"""Forward Pass of LPFormer.
Returns raw logits for each link
Expand Down Expand Up @@ -184,7 +184,7 @@ def forward(

def propagate(self, x: Tensor, adj: Adj) -> Tensor:
"""
Propagate via GNN
Propagate via GNN.
Args:
x (Tensor): Node features
Expand All @@ -198,7 +198,7 @@ def propagate(self, x: Tensor, adj: Adj) -> Tensor:

def calc_pairwise(self, batch: Tensor, X_node: Tensor, adj_mask: Tensor,
ppr_matrix: Tensor) -> Tensor:
r"""Calculate the pairwise features for the node pairs
r"""Calculate the pairwise features for the node pairs.
Args:
batch (Tensor): The batch vector.
Expand Down Expand Up @@ -241,7 +241,7 @@ def get_pos_encodings(
self, cn_ppr: Tuple[Tensor, Tensor],
onehop_ppr: Optional[Tuple[Tensor, Tensor]] = None,
non1hop_ppr: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
r"""Calculate the PPR-based relative positional encodings
r"""Calculate the PPR-based relative positional encodings.
Due to thresholds, sometimes we don't have 1-hop or >1-hop nodes.
In those cases, the value of onehop_ppr and/or non1hop_ppr should
Expand Down Expand Up @@ -281,8 +281,7 @@ def get_pos_encodings(
def compute_node_mask(
self, batch: Tensor, adj: Tensor, ppr_matrix: Tensor
) -> Tuple[Tuple, Optional[Tuple], Optional[Tuple]]:
r"""
Get mask based on type of node
r"""Get mask based on type of node.
When mask_type is not "cn", also return the ppr vals for both
the source and target.
Expand Down Expand Up @@ -447,7 +446,7 @@ def get_structure_cnts(
non1hop_info: Optional[Tuple[Tensor, Tensor]],
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold
Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold.
Also include total # of neighbors
Expand Down Expand Up @@ -484,7 +483,7 @@ def get_num_ppr_thresh(self, batch: Tensor, node_mask: Tensor,
src_ppr: Tensor, tgt_ppr: Tensor,
thresh: float) -> Tensor:
"""
Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`
Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`.
Args:
batch (Tensor): The batch vector.
Expand All @@ -510,7 +509,7 @@ def get_count(
batch: Tensor,
) -> Tensor:
"""
# of nodes for each sample in batch
# of nodes for each sample in batch.
They node have already filtered by PPR beforehand
Expand Down Expand Up @@ -598,7 +597,7 @@ def get_non_1hop_ppr(self, batch: Tensor, adj: Tensor,
def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int,
alpha: float = 0.15, eps: float = 5e-5) -> Tensor:
r"""
Calculate the PPR of the graph in sparse format
Calculate the PPR of the graph in sparse format.
Args:
edge_index: The edge indices
Expand Down Expand Up @@ -748,7 +747,7 @@ def message(self, x_i: Tensor, x_j: Tensor, ppr_rpes: Tensor,

class MLP(nn.Module):
"""
L Layer MLP
L Layer MLP.
"""
def __init__(self, in_channels: int, hid_channels: int, out_channels: int,
num_layers: int = 2, drop: int = 0, norm: str = "layer"):
Expand Down

0 comments on commit d992ff2

Please sign in to comment.