Skip to content

Commit

Permalink
nx-cugraph: add ego_graph (#4395)
Browse files Browse the repository at this point in the history
Authors:
  - Erik Welch (https://github.com/eriknw)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4395
  • Loading branch information
eriknw authored May 21, 2024
1 parent b9f6e8c commit 624e961
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 9 deletions.
2 changes: 2 additions & 0 deletions python/nx-cugraph/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ Below is the list of algorithms that are currently supported in nx-cugraph.
└─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.classic.wheel_graph.html#networkx.generators.classic.wheel_graph">wheel_graph</a>
<a href="https://networkx.org/documentation/stable/reference/generators.html#module-networkx.generators.community">community</a>
└─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.community.caveman_graph.html#networkx.generators.community.caveman_graph">caveman_graph</a>
<a href="https://networkx.org/documentation/stable/reference/generators.html#module-networkx.generators.ego">ego</a>
└─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.ego.ego_graph.html#networkx.generators.ego.ego_graph">ego_graph</a>
<a href="https://networkx.org/documentation/stable/reference/generators.html#module-networkx.generators.small">small</a>
├─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.small.bull_graph.html#networkx.generators.small.bull_graph">bull_graph</a>
├─ <a href="https://networkx.org/documentation/stable/reference/generated/networkx.generators.small.chvatal_graph.html#networkx.generators.small.chvatal_graph">chvatal_graph</a>
Expand Down
5 changes: 5 additions & 0 deletions python/nx-cugraph/_nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"diamond_graph",
"dodecahedral_graph",
"edge_betweenness_centrality",
"ego_graph",
"eigenvector_centrality",
"empty_graph",
"florentine_families_graph",
Expand Down Expand Up @@ -163,6 +164,7 @@
"clustering": "Directed graphs and `weight` parameter are not yet supported.",
"core_number": "Directed graphs are not yet supported.",
"edge_betweenness_centrality": "`weight` parameter is not yet supported, and RNG with seed may be different.",
"ego_graph": "Weighted ego_graph with negative cycles is not yet supported. `NotImplementedError` will be raised if there are negative `distance` edge weights.",
"eigenvector_centrality": "`nstart` parameter is not used, but it is checked for validity.",
"from_pandas_edgelist": "cudf.DataFrame inputs also supported; value columns with str is unsuppported.",
"generic_bfs_edges": "`neighbors` and `sort_neighbors` parameters are not yet supported.",
Expand Down Expand Up @@ -191,6 +193,9 @@
"bellman_ford_path_length": {
"dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
},
"ego_graph": {
"dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
},
"eigenvector_centrality": {
"dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.",
},
Expand Down
8 changes: 4 additions & 4 deletions python/nx-cugraph/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.16
rev: v0.17
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
Expand All @@ -50,7 +50,7 @@ repos:
- id: black
# - id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.4.4
hooks:
- id: ruff
args: [--fix-only, --show-fixes] # --unsafe-fixes]
Expand All @@ -62,7 +62,7 @@ repos:
additional_dependencies: &flake8_dependencies
# These versions need to be updated manually
- flake8==7.0.0
- flake8-bugbear==24.4.21
- flake8-bugbear==24.4.26
- flake8-simplify==0.21.0
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
Expand All @@ -77,7 +77,7 @@ repos:
additional_dependencies: [tomli]
files: ^(nx_cugraph|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.4.4
hooks:
- id: ruff
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
9 changes: 7 additions & 2 deletions python/nx-cugraph/nx_cugraph/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -562,7 +562,12 @@ def to_networkx(G: nxcg.Graph, *, sort_edges: bool = False) -> nx.Graph:
dst_iter = map(id_to_key.__getitem__, dst_indices)
if G.is_multigraph() and (G.edge_keys is not None or G.edge_indices is not None):
if G.edge_keys is not None:
edge_keys = G.edge_keys
if not G.is_directed():
edge_keys = [k for k, m in zip(G.edge_keys, mask.tolist()) if m]
else:
edge_keys = G.edge_keys
elif not G.is_directed():
edge_keys = G.edge_indices[mask].tolist()
else:
edge_keys = G.edge_indices.tolist()
if edge_values:
Expand Down
3 changes: 2 additions & 1 deletion python/nx-cugraph/nx_cugraph/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -12,5 +12,6 @@
# limitations under the License.
from .classic import *
from .community import *
from .ego import *
from .small import *
from .social import *
161 changes: 161 additions & 0 deletions python/nx-cugraph/nx_cugraph/generators/ego.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import cupy as cp
import networkx as nx
import numpy as np
import pylibcugraph as plc

import nx_cugraph as nxcg

from ..utils import _dtype_param, _get_float_dtype, index_dtype, networkx_algorithm

__all__ = ["ego_graph"]


@networkx_algorithm(
    extra_params=_dtype_param, version_added="24.06", _plc={"bfs", "ego_graph", "sssp"}
)
def ego_graph(
    G, n, radius=1, center=True, undirected=False, distance=None, *, dtype=None
):
    """Weighted ego_graph with negative cycles is not yet supported. `NotImplementedError` will be raised if there are negative `distance` edge weights."""  # noqa: E501
    # NOTE(review): the docstring above is kept byte-for-byte because the same
    # sentence also appears in `_nx_cugraph/__init__.py`; presumably the two must
    # stay in sync -- confirm before rewording it.
    #
    # Purpose: build the subgraph induced by the nodes within `radius` of `n`.
    #
    # Parameters
    #   G          : NetworkX graph (converted with all attributes preserved) or
    #                an nx_cugraph graph.
    #   n          : key of the center node; must be in `G`.
    #   radius     : distance limit. `None` means no limit.
    #   center     : if False, `n` itself is left out of the result.
    #   undirected : for a directed `G`, search as if edges went both ways.
    #   distance   : edge-attribute key to use as edge weight. When it is `None`
    #                or is not a key of `G.edge_values`, an unweighted BFS is used.
    #   dtype      : forwarded to `_get_float_dtype` to pick the weight dtype
    #                (weighted path only).
    #
    # Returns
    #   `G.copy()` when every node is selected; otherwise a new graph of the same
    #   class as the (converted) `G`, built with `from_coo` and renumbered ids.
    #
    # Raises
    #   nx.NodeNotFound      : `n` is not in `G`.
    #   NotImplementedError  : callable `distance`; weighted search with
    #                          `undirected=True` on a directed multigraph;
    #                          any negative `distance` edge value.
    if isinstance(G, nx.Graph):
        G = nxcg.from_networkx(G, preserve_all_attrs=True)
    if n not in G:
        # Two different messages on purpose: the accompanying tests match
        # "not in G" without `distance` and "not found in graph" with it.
        if distance is None:
            raise nx.NodeNotFound(f"Source {n} is not in G")
        raise nx.NodeNotFound(f"Node {n} not found in graph")
    if callable(distance):
        # Reject up front. A callable is never a key of `G.edge_values`, so it
        # would otherwise take the unweighted BFS branch below and the weights
        # it describes would be silently ignored.
        raise NotImplementedError("callable `distance` argument is not supported")
    src_index = n if G.key_to_id is None else G.key_to_id[n]
    symmetrize = "union" if undirected and G.is_directed() else None

    if distance is None or distance not in G.edge_values:
        # Simple BFS to determine nodes
        if radius is not None and radius <= 0:
            # Nothing but the source can be within a non-positive radius, so
            # skip the BFS call entirely.
            if center:
                node_ids = cp.array([src_index], dtype=index_dtype)
            else:
                node_ids = cp.empty(0, dtype=index_dtype)
            node_mask = None
        else:
            if radius is None or np.isinf(radius):
                # presumably -1 tells plc.bfs "no depth limit" -- TODO confirm
                radius = -1
            else:
                # depth_limit is passed as an integer; fractional radii round up
                radius = math.ceil(radius)
            distances, unused_predecessors, node_ids = plc.bfs(
                handle=plc.ResourceHandle(),
                graph=G._get_plc_graph(symmetrize=symmetrize),
                sources=cp.array([src_index], index_dtype),
                direction_optimizing=False,  # True for undirected only; what's best?
                depth_limit=radius,
                compute_predecessors=False,
                do_expensive_check=False,
            )
            # Drop entries whose distance is the integer dtype's max value
            # (presumably the marker for vertices that were not reached).
            node_mask = distances != np.iinfo(distances.dtype).max
    else:
        # SSSP to determine nodes
        if symmetrize and G.is_multigraph():
            # G._get_plc_graph does not implement `symmetrize=True` w/ edge array
            raise NotImplementedError(
                "Weighted ego_graph with undirected=True not implemented"
            )
        # Check for negative values since we don't support negative cycles
        edge_vals = G.edge_values[distance]
        if distance in G.edge_masks:
            # Only look at edges for which this attribute is actually set
            edge_vals = edge_vals[G.edge_masks[distance]]
        if (edge_vals < 0).any():
            raise NotImplementedError(
                "Negative edge weights not yet supported by ego_graph"
            )
        # PERF: we could use BFS if all edges are equal
        if radius is None:
            radius = np.inf
        dtype = _get_float_dtype(dtype, graph=G, weight=distance)
        node_ids, distances, unused_predecessors = plc.sssp(
            resource_handle=plc.ResourceHandle(),
            graph=(G.to_undirected() if symmetrize else G)._get_plc_graph(
                distance, 1, dtype
            ),
            source=src_index,
            # Next representable float above `radius`; presumably needed because
            # plc.sssp treats `cutoff` as an exclusive bound -- TODO confirm
            cutoff=np.nextafter(radius, np.inf, dtype=np.float64),
            compute_predecessors=True,  # TODO: False is not yet supported
            do_expensive_check=False,
        )
        # Same filtering as the BFS branch, with the float dtype's max value
        node_mask = distances != np.finfo(distances.dtype).max

    if node_mask is not None:
        if not center:
            node_mask &= node_ids != src_index
        node_ids = node_ids[node_mask]
    if node_ids.size == G._N:
        # Every node was selected: no filtering or renumbering needed
        return G.copy()

    # TODO: create renumbering helper function(s)
    node_ids.sort()  # TODO: is this ever necessary? Keep for safety
    node_values = {key: val[node_ids] for key, val in G.node_values.items()}
    node_masks = {key: val[node_ids] for key, val in G.node_masks.items()}

    G._sort_edge_indices()  # TODO: is this ever necessary? Keep for safety
    # Keep only edges whose two endpoints were both selected
    edge_mask = cp.isin(G.src_indices, node_ids) & cp.isin(G.dst_indices, node_ids)
    # `node_ids` is sorted, so searchsorted maps each old id to its new position
    src_indices = cp.searchsorted(node_ids, G.src_indices[edge_mask]).astype(
        index_dtype
    )
    dst_indices = cp.searchsorted(node_ids, G.dst_indices[edge_mask]).astype(
        index_dtype
    )
    edge_values = {key: val[edge_mask] for key, val in G.edge_values.items()}
    edge_masks = {key: val[edge_mask] for key, val in G.edge_masks.items()}

    # Renumber nodes: map each surviving node key to its new, compact id.
    # Without an `id_to_key` mapping the old integer id itself is the key.
    old_ids = node_ids.tolist()
    id_to_key = G.id_to_key
    if id_to_key is not None:
        key_to_id = {id_to_key[old]: new for new, old in enumerate(old_ids)}
    else:
        key_to_id = {old: new for new, old in enumerate(old_ids)}

    kwargs = {
        "N": node_ids.size,
        "src_indices": src_indices,
        "dst_indices": dst_indices,
        "edge_values": edge_values,
        "edge_masks": edge_masks,
        "node_values": node_values,
        "node_masks": node_masks,
        "key_to_id": key_to_id,
    }
    if G.is_multigraph():
        # Multigraph edge identifiers must be filtered with the same edge mask
        if G.edge_keys is not None:
            kwargs["edge_keys"] = [
                x for x, m in zip(G.edge_keys, edge_mask.tolist()) if m
            ]
        if G.edge_indices is not None:
            kwargs["edge_indices"] = G.edge_indices[edge_mask]
    rv = G.__class__.from_coo(**kwargs)
    rv.graph.update(G.graph)  # carry over graph-level attributes
    return rv


@ego_graph._can_run
def _(G, n, radius=1, center=True, undirected=False, distance=None, *, dtype=None):
    # Decide whether `ego_graph` supports this call.
    # Returns True when it does, otherwise a string stating the reason it does not.
    if distance is None:
        # Every restriction below concerns the weighted code path only
        return True
    if undirected and G.is_directed() and G.is_multigraph():
        return "Weighted ego_graph with undirected=True not implemented"
    if callable(distance):
        # Cheap type test first: no point scanning the edges for negative
        # weights when the call is going to be rejected regardless.
        return "callable `distance` argument is not supported"
    if nx.is_negatively_weighted(G, weight=distance):
        return "Weighted ego_graph with negative cycles not yet supported"
    return True
81 changes: 81 additions & 0 deletions python/nx-cugraph/nx_cugraph/tests/test_ego_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import networkx as nx
import pytest
from packaging.version import parse

import nx_cugraph as nxcg

from .testing_utils import assert_graphs_equal

# Parsed NetworkX version; used by the module-level skip below and for the
# version-dependent expectations inside the test.
nxver = parse(nx.__version__)


# Skip every test in this module when running on NetworkX 3.0 or 3.1.
if nxver.major == 3 and nxver.minor < 2:
    pytest.skip("Need NetworkX >=3.2 to test ego_graph", allow_module_level=True)


@pytest.mark.parametrize(
    "create_using", [nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph]
)
@pytest.mark.parametrize("radius", [-1, 0, 1, 1.5, 2, float("inf"), None])
@pytest.mark.parametrize("center", [True, False])
@pytest.mark.parametrize("undirected", [False, True])
@pytest.mark.parametrize("multiple_edges", [False, True])
@pytest.mark.parametrize("n", [0, 3])
def test_ego_graph_cycle_graph(
    create_using, radius, center, undirected, multiple_edges, n
):
    """Compare the cugraph backend's ``ego_graph`` with NetworkX on a 7-cycle."""
    cugraph_backend = {"backend": "cugraph"}

    graph = nx.cycle_graph(7, create_using=create_using)
    if multiple_edges and not graph.is_multigraph():
        # Parallel edges can only be exercised on multigraphs
        return
    if multiple_edges:
        graph.add_edges_from(nx.cycle_graph(7, create_using=nx.DiGraph).edges)
        graph.add_edge(0, 1, 10)
    # Sanity check: conversion to nx-cugraph round-trips
    assert_graphs_equal(graph, nxcg.from_networkx(graph, preserve_all_attrs=True))

    # --- Unweighted (BFS) ---
    options = {"radius": radius, "center": center, "undirected": undirected}
    expected = nx.ego_graph(graph, n, **options)
    actual = nx.ego_graph(graph, n, **options, **cugraph_backend)
    assert_graphs_equal(expected, actual)
    for backend_kwargs in ({}, cugraph_backend):
        with pytest.raises(nx.NodeNotFound, match="not in G"):
            nx.ego_graph(graph, -1, **options, **backend_kwargs)

    # --- Weighted (SSSP) ---
    # Using sssp with default weight of 1 should give same answer as bfs
    nx.set_edge_attributes(graph, 1, name="weight")
    gpu_graph = nxcg.from_networkx(graph, preserve_all_attrs=True)
    assert_graphs_equal(graph, gpu_graph)  # Sanity check

    options["distance"] = "weight"
    expected_weighted = nx.ego_graph(graph, n, **options)

    if undirected and graph.is_directed() and graph.is_multigraph():
        # Unsupported combination: only the error reporting can be checked.
        # `should_run` was added in nx 3.3, so the message differs on nx 3.2.
        on_nx32 = nxver.major == 3 and nxver.minor == 2
        match = (
            "Weighted ego_graph with undirected=True not implemented"
            if on_nx32
            else "not implemented by cugraph"
        )
        with pytest.raises(RuntimeError, match=match):
            nx.ego_graph(graph, n, **options, **cugraph_backend)
        with pytest.raises(NotImplementedError, match="ego_graph"):
            nx.ego_graph(gpu_graph, n, **options)
        return

    actual_weighted = nx.ego_graph(graph, n, **options, **cugraph_backend)
    assert_graphs_equal(expected_weighted, actual_weighted)
    for backend_kwargs in ({}, cugraph_backend):
        with pytest.raises(nx.NodeNotFound, match="not found in graph"):
            nx.ego_graph(graph, -1, **options, **backend_kwargs)
5 changes: 3 additions & 2 deletions python/nx-cugraph/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.

[build-system]

Expand All @@ -19,7 +19,7 @@ authors = [
license = { text = "Apache 2.0" }
requires-python = ">=3.9"
classifiers = [
"Development Status :: 3 - Alpha",
"Development Status :: 4 - Beta",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
Expand Down Expand Up @@ -233,6 +233,7 @@ ignore = [
"nx_cugraph/**/tests/*py" = ["S101", "S311", "T201", "D103", "D100"]
"_nx_cugraph/__init__.py" = ["E501"]
"nx_cugraph/algorithms/**/*py" = ["D205", "D401"] # Allow flexible docstrings for algorithms
"nx_cugraph/generators/**/*py" = ["D205", "D401"] # Allow flexible docstrings for generators
"nx_cugraph/interface.py" = ["D401"] # Flexible docstrings
"scripts/update_readme.py" = ["INP001"] # Not part of a package

Expand Down

0 comments on commit 624e961

Please sign in to comment.