Skip to content

Commit

Permalink
Use Vertex ID Offsets in CuGraphStorage (#2996)
Browse files Browse the repository at this point in the history
Users Vertex ID offsets in `CuGraphStore`.  Offers the option of using pre-computed offsets by setting `renumber_graph=False`.  Default behavior is to warn the user that it is computing offsets and save the original vertex ids in a new column.  If `renumber_graph=True`, there is no warning, `renumber_vertices_by_type()` is called, and the vertex ids are overwritten.

Due to a bug in `MGPropertyGraph`, if vertices are renumbered, edges must also be renumbered.

Closes rapidsai/graph_dl#95
Closes #3069 

- [x] Update Remote Graph tests
- [x] Update CGS e2e tests
- [x] Update SG PyG Extension Tests
- [x] Update MG PyG Extension Tests
- [x] Verify tests that are currently broken in CI

Includes fix for #3069 
Merge after #3012

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Erik Welch (https://github.com/eriknw)
  - Rick Ratzel (https://github.com/rlratzel)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)
  - Vibhu Jawa (https://github.com/VibhuJawa)

URL: #2996
  • Loading branch information
alexbarghi-nv authored Jan 12, 2023
1 parent c98da66 commit ef52c5c
Show file tree
Hide file tree
Showing 19 changed files with 1,118 additions and 266 deletions.
183 changes: 155 additions & 28 deletions python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -14,13 +14,32 @@
from typing import Optional, Tuple, Any
from enum import Enum

import cupy

from dataclasses import dataclass
from collections import defaultdict
from itertools import chain
from functools import cached_property

import warnings

# numpy is always available
import numpy as np

# cuGraph or cuGraph-Service is required; each has its own version of
# import_optional and we need to select the correct one.
try:
from cugraph_service.client.remote_graph_utils import import_optional
except ModuleNotFoundError:
try:
from cugraph.utilities.utils import import_optional
except ModuleNotFoundError:
raise ModuleNotFoundError(
"cuGraph-PyG requires cuGraph" "or cuGraph-Service to be installed."
)

# FIXME drop cupy support and make torch the only backend (#2995)
cupy = import_optional("cupy")
torch = import_optional("torch")


class EdgeLayout(Enum):
COO = "coo"
Expand Down Expand Up @@ -74,7 +93,7 @@ def cast(cls, *args, **kwargs):
return cls(*args, **kwargs)


def EXPERIMENTAL__to_pyg(G, backend="torch"):
def EXPERIMENTAL__to_pyg(G, backend="torch", renumber_graph=None):
"""
Returns the PyG wrappers for the provided PropertyGraph or
MGPropertyGraph.
Expand All @@ -83,13 +102,20 @@ def EXPERIMENTAL__to_pyg(G, backend="torch"):
----------
G : PropertyGraph or MGPropertyGraph
The graph to produce PyG wrappers for.
renumber_graph: bool
Should usually be set to True. If True, the vertices and edges
in the provided property graph will be renumbered so that they
are contiguous by type. If the vertices and edges are already
contiguously renumbered by type, then this can be set to False.
Returns
-------
Tuple (CuGraphStore, CuGraphStore)
Wrappers for the provided property graph.
"""
store = EXPERIMENTAL__CuGraphStore(G, backend=backend)
store = EXPERIMENTAL__CuGraphStore(
G, backend=backend, renumber_graph=renumber_graph
)
return (store, store)


Expand Down Expand Up @@ -178,7 +204,7 @@ class EXPERIMENTAL__CuGraphStore:
Duck-typed version of PyG's GraphStore and FeatureStore.
"""

def __init__(self, G, reserved_keys=[], backend="torch"):
def __init__(self, G, backend="torch", renumber_graph=None):
"""
Constructs a new CuGraphStore from the provided
arguments.
Expand All @@ -188,27 +214,35 @@ def __init__(self, G, reserved_keys=[], backend="torch"):
G : PropertyGraph or MGPropertyGraph
The cuGraph property graph where the
data is being stored.
reserved_keys : Properties in the graph that are not used for
training (the 'x' attribute will ignore these properties).
backend : The backend that manages tensors (default = 'torch')
backend : ('torch', 'cupy')
The backend that manages tensors (default = 'torch')
Should usually be 'torch' ('torch', 'cupy' supported).
renumber_graph : bool
If True, will renumber vertices and edges to have contiguous
ids per type. If False, will not renumber vertices. If not
specified, will renumber and raise a warning.
"""

# TODO ensure all x properties are float32 type
# TODO ensure y is of long type
# FIXME ensure all x properties are float32 type
# FIXME ensure y is of long type
if None in G.edge_types:
raise ValueError("Unspecified edge types not allowed in PyG")

# FIXME drop the cupy backend and remove these checks (#2995)
if backend == "torch":
from torch.utils.dlpack import from_dlpack
from torch import int64 as vertex_dtype
from torch import float32 as property_dtype
from torch import searchsorted as searchsorted
from torch import concatenate as concatenate
from torch import arange as arange
elif backend == "cupy":
from cupy import from_dlpack
from cupy import int64 as vertex_dtype
from cupy import float32 as property_dtype
from cupy import searchsorted as searchsorted
from cupy import concatenate as concatenate
from cupy import arange as arange
else:
raise ValueError(f"Invalid backend {backend}.")

Expand All @@ -217,19 +251,21 @@ def __init__(self, G, reserved_keys=[], backend="torch"):
self.vertex_dtype = vertex_dtype
self.property_dtype = property_dtype
self.searchsorted = searchsorted
self.concatenate = concatenate
self.arange = arange

self.__graph = G
self.__subgraphs = {}

self.__reserved_keys = [
self.__graph.type_col_name,
self.__graph.vertex_col_name,
] + list(reserved_keys)

self._tensor_attr_cls = CuGraphTensorAttr
self._tensor_attr_dict = defaultdict(list)
self.__infer_x_and_y_tensors()

# Must be called after __infer_x_and_y_tensors to
# avoid adding the old vertex id as a property when
# users do not specify it.
self.__renumber_graph(renumber_graph)

self.__edge_types_to_attrs = {}
for edge_type in self.__graph.edge_types:
edges = self.__graph.get_edge_data(types=[edge_type])
Expand Down Expand Up @@ -269,6 +305,92 @@ def __init__(self, G, reserved_keys=[], backend="torch"):

self._edge_attr_cls = CuGraphEdgeAttr

def __renumber_graph(self, renumber_graph):
"""
Renumbers the vertices and edges in this store's property graph
and sets the vertex offsets.
If renumber_graph is False, then renumber_vertices_by_type()
and renumber_edges_by_type()
are not called and the offsets are inferred from vertex counts.
If renumber_graph is None, it defaults to True, warns the
user of this default behavior, and saves the current ids as
<vertex_col>_old.
If renumber_graph is True, it calls renumber_vertices_by_type()
and renumber_edges_by_type(),
overwriting the current vertex and edge ids without saving them.
"""
self.__old_vertex_col_name = None
self.__old_edge_col_name = None

if renumber_graph is None:
renumber_graph = True
self.__old_vertex_col_name = f"{self.__graph.vertex_col_name}_old"
self.__old_edge_col_name = f"{self.__graph.edge_id_col_name}_old"
warnings.warn(
f"renumber_graph not specified; renumbering by default "
f"and saving as {self.__old_vertex_col_name} "
f"and {self.__old_edge_col_name}"
)

if renumber_graph:
if self.is_remote and self.backend == "torch":
self.__vertex_type_offsets = self.__graph.renumber_vertices_by_type(
prev_id_column=self.__old_vertex_col_name,
backend="torch:cuda" if torch.has_cuda else "torch",
)
else:
self.__vertex_type_offsets = self.__graph.renumber_vertices_by_type(
prev_id_column=self.__old_vertex_col_name
)

# FIXME: https://github.com/rapidsai/cugraph/issues/3059
# Currently renumbering edges is required if renumbering vertices or else
# there is a dask partitioning issue.
self.__graph.renumber_edges_by_type(prev_id_column=self.__old_edge_col_name)

else:
self.__vertex_type_offsets = {}
self.__vertex_type_offsets["stop"] = [
self.__graph.get_num_vertices(vt) for vt in self.__graph.vertex_types
]
if self.__backend == "cupy":
self.__vertex_type_offsets["stop"] = cupy.array(
self.__vertex_type_offsets["stop"]
)
else:
self.__vertex_type_offsets["stop"] = torch.tensor(
self.__vertex_type_offsets["stop"]
)
if torch.has_cuda:
self.__vertex_type_offsets["stop"] = self.__vertex_type_offsets[
"stop"
].cuda()

cumsum = self.__vertex_type_offsets["stop"].cumsum(0)
self.__vertex_type_offsets["start"] = (
self.__vertex_type_offsets["stop"] - cumsum
)
self.__vertex_type_offsets["stop"] -= 1
self.__vertex_type_offsets["type"] = np.array(self.__graph.vertex_types)

@property
def _old_vertex_col_name(self):
"""
Returns the name of the new property in the wrapped property graph where
the original vertex ids were stored, if this store did its own renumbering.
"""
return self.__old_vertex_col_name

@property
def _old_edge_col_name(self):
"""
Returns the name of the new property in the wrapped property graph where
the original edge ids were stored, if this store did its own renumbering.
"""
return self.__old_edge_col_name

@property
def _edge_types_to_attrs(self):
return dict(self.__edge_types_to_attrs)
Expand All @@ -290,25 +412,29 @@ def is_multi_gpu(self):

@cached_property
def is_remote(self):
return self.__graph.is_remote()
pg_types = ["PropertyGraph", "MGPropertyGraph"]
if type(self.__graph).__name__ in pg_types:
return False
else:
return self.__graph.is_remote()

@cached_property
def _is_delayed(self):
return self.is_multi_gpu and not self.is_remote

def get_vertex_index(self, vtypes):
# TODO force the graph to use offsets and
# return these values based on offsets

if isinstance(vtypes, str):
vtypes = [vtypes]

ix = self.__graph.get_vertex_data(
types=vtypes, columns=[self.__graph.type_col_name]
)[self.__graph.vertex_col_name]

if self._is_delayed:
ix = ix.compute()
# FIXME always use torch, drop cupy (#2995)
if self.__backend == "torch":
ix = torch.tensor()
else:
ix = cupy.array()
for vtype in vtypes:
start = self.__vertex_type_offsets["start"][vtype]
stop = self.__vertex_type_offsets["stop"][vtype]
ix = self.concatenate(ix, self.arange(start, stop + 1, 1))

return self.from_dlpack(ix.to_dlpack())

Expand Down Expand Up @@ -461,9 +587,10 @@ def _subgraph(self, edge_types):
edge_types = tuple(sorted(edge_types))

if edge_types not in self.__subgraphs:
query = f'(_TYPE_=="{edge_types[0]}")'
TCN = self.__graph.type_col_name
query = f'({TCN}=="{edge_types[0]}")'
for t in edge_types[1:]:
query += f' | (_TYPE_=="{t}")'
query += f' | ({TCN}=="{t}")'
selection = self.__graph.select_edges(query)

# FIXME enforce int type
Expand Down
5 changes: 2 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/loader/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -20,8 +20,7 @@
from cugraph.utilities.utils import import_optional
except ModuleNotFoundError:
raise ModuleNotFoundError(
"cuGraph extensions for PyG require cuGraph"
"or cuGraph-Service to be installed."
"cuGraph-PyG requires cugraph or cugraph-service-client to be installed."
)

_transform_to_backend_dtype_1d = import_optional(
Expand Down
68 changes: 68 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pytest

from dask_cuda.initialize import initialize as dask_initialize
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from cugraph.dask.comms import comms as Comms
from cugraph.dask.common.mg_utils import get_visible_devices
from cugraph.testing.mg_utils import stop_dask_client

import tempfile

# module-wide fixtures

# If the rapids-pytest-benchmark plugin is installed, the "gpubenchmark"
# fixture will be available automatically. Check that this fixture is available
# by trying to import rapids_pytest_benchmark, and if that fails, set
# "gpubenchmark" to the standard "benchmark" fixture provided by
# pytest-benchmark.
try:
import rapids_pytest_benchmark # noqa: F401
except ImportError:
import pytest_benchmark

gpubenchmark = pytest_benchmark.plugin.benchmark


@pytest.fixture(scope="module")
def dask_client():
dask_scheduler_file = os.environ.get("SCHEDULER_FILE")
cuda_visible_devices = get_visible_devices()

if dask_scheduler_file is not None:
dask_initialize()
dask_client = Client(scheduler_file=dask_scheduler_file)
else:
# The tempdir created by tempdir_object should be cleaned up once
# tempdir_object goes out-of-scope and is deleted.
tempdir_object = tempfile.TemporaryDirectory()
cluster = LocalCUDACluster(
local_directory=tempdir_object.name,
protocol="tcp",
CUDA_VISIBLE_DEVICES=cuda_visible_devices,
)

dask_client = Client(cluster)
dask_client.wait_for_workers(len(cuda_visible_devices))

if not Comms.is_initialized():
Comms.initialize(p2p=True)

yield dask_client

stop_dask_client(dask_client)
print("\ndask_client fixture: client.close() called")
Loading

0 comments on commit ef52c5c

Please sign in to comment.