Use Vertex ID Offsets in CuGraphStorage #2996

Merged
merged 183 commits into from
Jan 12, 2023
Changes from 169 commits
Commits
2a6c9cf
PropertyGraph set index to vertex and edge ids
eriknw Aug 9, 2022
4c93f77
Update graph_store
eriknw Aug 10, 2022
2631715
flake8
eriknw Aug 10, 2022
ff0f80c
Merge branch 'branch-22.10' into pg_set_index
eriknw Sep 7, 2022
99c2e0e
Set index to vertex or edge IDs in PG for MG
eriknw Sep 14, 2022
1a1e039
Merge branch 'pg_set_index' of https://github.com/eriknw/cugraph into…
alexbarghi-nv Sep 21, 2022
9bbf048
fixes
alexbarghi-nv Sep 21, 2022
4496cba
merge
alexbarghi-nv Oct 12, 2022
ccae80b
Fix concat with different index dtypes in SG PropertyGraph
eriknw Oct 13, 2022
824d083
initial
alexbarghi-nv Oct 17, 2022
5aaa90d
initial work on remote wrappers, very rough
alexbarghi-nv Oct 18, 2022
52fe830
merge resolution
alexbarghi-nv Oct 18, 2022
3221911
additional functionality, v/e counts
alexbarghi-nv Oct 18, 2022
f097043
copyright update
alexbarghi-nv Oct 18, 2022
7d33ed6
additional functions
alexbarghi-nv Oct 18, 2022
d14ae24
quick fix
alexbarghi-nv Oct 18, 2022
10bf725
Merge branch 'sgpg_fix_concat' of https://github.com/eriknw/cugraph i…
alexbarghi-nv Oct 18, 2022
1887ce7
add definition for remote graph, tests for pg
alexbarghi-nv Oct 18, 2022
f598dbe
remove dispatch (will be added in other pr)
alexbarghi-nv Oct 18, 2022
6b34f5b
Merge branch 'branch-22.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Oct 18, 2022
8495b70
revert inadvertently changed file
alexbarghi-nv Oct 18, 2022
7ffd777
Merge branch 'branch-22.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Oct 19, 2022
5089def
initial changes
alexbarghi-nv Oct 20, 2022
c157076
update version
alexbarghi-nv Oct 20, 2022
bc400ca
Merge branch 'branch-22.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Oct 20, 2022
ab3a28e
pull in dispatch from other branch
alexbarghi-nv Oct 20, 2022
438bfff
dispatch
alexbarghi-nv Oct 21, 2022
50fd0df
fix get_vertices(), add tests
alexbarghi-nv Oct 21, 2022
7fe9b0f
tests, fixes
alexbarghi-nv Oct 21, 2022
0a88a52
fix typo
alexbarghi-nv Oct 21, 2022
ae87b94
major changes to update output array/dataframe/tensor handling, unit/…
alexbarghi-nv Oct 25, 2022
ec44561
Merge branch 'cgs-remote-wrappers' of https://github.com/alexbarghi-n…
alexbarghi-nv Oct 25, 2022
c7d7112
fix merge conflict
alexbarghi-nv Oct 25, 2022
5703d41
fix version
alexbarghi-nv Oct 25, 2022
c8379ad
infer default backend
alexbarghi-nv Oct 25, 2022
3aed33d
fix default backend for remote pg
alexbarghi-nv Oct 25, 2022
ce12b47
reverse this commit
alexbarghi-nv Oct 25, 2022
092db5e
Revert "reverse this commit"
alexbarghi-nv Oct 25, 2022
e1a3c1f
Merge branch 'branch-22.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Oct 25, 2022
a66e437
remove useless code from pg, remove print statement
alexbarghi-nv Oct 26, 2022
a18b336
move backend call to methods, add graph() factory, update tests
alexbarghi-nv Oct 26, 2022
865ca44
fix version
alexbarghi-nv Oct 26, 2022
0975038
fix get vertex/edge data with types in cgs handler, minor raii fix, u…
alexbarghi-nv Oct 26, 2022
8f28820
fix version
alexbarghi-nv Oct 26, 2022
c8289f6
update branch
alexbarghi-nv Oct 26, 2022
03a1cf2
minor fix
alexbarghi-nv Oct 26, 2022
9b2ff76
add loader fix initial code
alexbarghi-nv Oct 28, 2022
e007684
fixes
alexbarghi-nv Nov 1, 2022
e1d4b84
Resolve merge conflict
alexbarghi-nv Nov 1, 2022
bf3df8a
cleanup, fixes for renumbering
alexbarghi-nv Nov 1, 2022
7931832
support for pg api
alexbarghi-nv Nov 1, 2022
668f95f
sampling, algo calls, implicit sg, fixes for multigraph
alexbarghi-nv Nov 2, 2022
3fc56be
fix version
alexbarghi-nv Nov 2, 2022
4955f90
remove print statements
alexbarghi-nv Nov 2, 2022
5b2a917
Merge branch 'cgs-remote-sample' into loader_fix
alexbarghi-nv Nov 2, 2022
90f700f
resolve merge conflict
alexbarghi-nv Nov 7, 2022
c96be0a
fix version
alexbarghi-nv Nov 8, 2022
8dc069e
fix version
alexbarghi-nv Nov 9, 2022
64b7d82
rename columns
alexbarghi-nv Nov 9, 2022
be41c53
switch to import_optional
alexbarghi-nv Nov 9, 2022
8462a24
minor cleanup
alexbarghi-nv Nov 9, 2022
53020e2
prevent copy in numpy to numpy conversion
alexbarghi-nv Nov 9, 2022
ddfb89d
is_mg -> is_multi_gpu
alexbarghi-nv Nov 9, 2022
2548efd
point to new issue
alexbarghi-nv Nov 9, 2022
09a5be4
point to new issue
alexbarghi-nv Nov 9, 2022
260ab2e
Merge branch 'cgs-remote-sample' into loader_fix
alexbarghi-nv Nov 9, 2022
24f9c85
add fillna to property graph
alexbarghi-nv Nov 9, 2022
9716f3f
fix notebook
alexbarghi-nv Nov 9, 2022
261eb29
remove include code
alexbarghi-nv Nov 9, 2022
c6d49ba
update version
alexbarghi-nv Nov 9, 2022
091c958
test, doc updates
alexbarghi-nv Nov 9, 2022
37fc74d
add different check for sg/mg
alexbarghi-nv Nov 9, 2022
29d51f0
update is_mg calls
alexbarghi-nv Nov 9, 2022
3490c26
fix version
alexbarghi-nv Nov 9, 2022
8f8470e
Merge branch 'branch-22.12' into loader_fix
alexbarghi-nv Nov 9, 2022
43b2c5e
resolve conflict
alexbarghi-nv Nov 10, 2022
2fd134a
split fillna, remove 'inplace'
alexbarghi-nv Nov 14, 2022
e5c54d7
remove unwanted files
alexbarghi-nv Nov 14, 2022
d726c70
Merge branch 'branch-22.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Nov 15, 2022
ac300c5
restore updated comments
alexbarghi-nv Nov 15, 2022
780fe9a
remove unwanted files
alexbarghi-nv Nov 15, 2022
bd3ee5a
formatting cleanup
alexbarghi-nv Nov 15, 2022
733b557
update the pyg extension tests
alexbarghi-nv Nov 15, 2022
e35fb25
test fix
alexbarghi-nv Nov 15, 2022
f0aebcc
update pyg tests
alexbarghi-nv Nov 15, 2022
fe59a9b
update notebook to use new fillna
alexbarghi-nv Nov 15, 2022
6f0f585
remove unwanted file
alexbarghi-nv Nov 15, 2022
0a2bf88
Merge branch 'branch-22.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Nov 15, 2022
ddfd200
start nb tests
alexbarghi-nv Nov 15, 2022
7618f79
remove egg files
alexbarghi-nv Nov 15, 2022
68b0885
t
alexbarghi-nv Nov 15, 2022
1e5c015
clarify in docstring that general Series is accepted
alexbarghi-nv Nov 16, 2022
cdd3d47
clarify in docstring general Series accepted
alexbarghi-nv Nov 16, 2022
c576d7a
clean up formatting in test_property_graph
alexbarghi-nv Nov 16, 2022
d34a1c8
clean up formatting in test_mg_property_graph
alexbarghi-nv Nov 16, 2022
cdf836e
formatting fix for test_property_graph
alexbarghi-nv Nov 16, 2022
2d8cc1b
reformat
alexbarghi-nv Nov 16, 2022
82d7e85
Merge branch 'branch-22.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Nov 16, 2022
41ef521
Merge branch 'loader_fix' into cgs-pyg
alexbarghi-nv Nov 17, 2022
1e4ae74
finish cgs support
alexbarghi-nv Nov 17, 2022
5e45d81
add back dropped file
alexbarghi-nv Nov 17, 2022
41f8067
clarify purpose of test
alexbarghi-nv Nov 17, 2022
a441aad
copyright fix
alexbarghi-nv Nov 17, 2022
5c92504
remove print statement
alexbarghi-nv Nov 21, 2022
a55e1be
remove egg
alexbarghi-nv Nov 21, 2022
68d8f0b
remove egg
alexbarghi-nv Nov 21, 2022
7dda411
resolve merge conflict
alexbarghi-nv Nov 21, 2022
c329883
remove print statements I thought had already been removed
alexbarghi-nv Nov 21, 2022
379181c
remove another print statement
alexbarghi-nv Nov 21, 2022
f4f9f7e
remove more print statements
alexbarghi-nv Nov 21, 2022
cf245bb
manual revert
alexbarghi-nv Nov 21, 2022
c09eeb6
manual revert #2
alexbarghi-nv Nov 21, 2022
b7313af
docstring formatting
alexbarghi-nv Nov 21, 2022
b8ee9d4
Revert "docstring formatting"
alexbarghi-nv Nov 21, 2022
b2e7ad6
Update cugraph_store.py
alexbarghi-nv Nov 21, 2022
b8eff8e
fix docstring
alexbarghi-nv Nov 21, 2022
ab8d489
update docstrings for remote_graph
alexbarghi-nv Nov 21, 2022
80377d2
style
alexbarghi-nv Nov 21, 2022
7ef8b33
accept empty property key list
alexbarghi-nv Nov 21, 2022
a106169
Merge branch 'branch-22.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Nov 23, 2022
290d5d6
fix bug with empty weight list
alexbarghi-nv Nov 23, 2022
1ab8c22
fix style
alexbarghi-nv Nov 23, 2022
b0ebc43
remove unwanted files
alexbarghi-nv Nov 23, 2022
5f5bc22
revert tests
alexbarghi-nv Nov 23, 2022
7ffb895
remove unwanted files
alexbarghi-nv Nov 23, 2022
4b6ab32
run pre-commit
alexbarghi-nv Nov 23, 2022
319e2b6
Revert unintentional changes to scripts
alexbarghi-nv Nov 28, 2022
5c2bcdb
change compute_required to is_delayed
alexbarghi-nv Nov 28, 2022
591f4de
remove rmm pool
alexbarghi-nv Nov 28, 2022
732c151
add exception if unable to process array
alexbarghi-nv Nov 28, 2022
4dde9a5
add bool variable for module installation
alexbarghi-nv Nov 28, 2022
daade97
add cudf installed variable
alexbarghi-nv Nov 28, 2022
2d3f608
t
alexbarghi-nv Nov 29, 2022
c7dea6d
p
alexbarghi-nv Nov 29, 2022
7683ea4
cugraph store
alexbarghi-nv Nov 30, 2022
5a305c6
Added check for server subprocess exitcode while waiting for it to st…
rlratzel Nov 30, 2022
adb764a
Merge branch 'branch-23.02' into cp-offsets
alexbarghi-nv Nov 30, 2022
67308bf
Merge branch 'branch-23.02-cugraph_service_test_error_output' of http…
alexbarghi-nv Nov 30, 2022
123eb0c
update remote graph tests
alexbarghi-nv Nov 30, 2022
72294ca
update test_e2e
alexbarghi-nv Nov 30, 2022
948bb45
cgs
alexbarghi-nv Dec 5, 2022
79beaae
Merge branch 'branch-23.02' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Dec 5, 2022
b39e3a4
reorg tests
alexbarghi-nv Dec 5, 2022
719c3f6
Merge branch 'branch-23.02' into cp-offsets
alexbarghi-nv Dec 6, 2022
85680b5
Always build without isolation.
vyasr Dec 6, 2022
f21c9d0
various changes to get tests to pass
alexbarghi-nv Dec 6, 2022
0c35cfd
Merge branch 'fix/devel_build_isolation' of https://github.com/vyasr/…
alexbarghi-nv Dec 6, 2022
a54d898
test updates
alexbarghi-nv Dec 8, 2022
19376b5
fix merge conflict
alexbarghi-nv Dec 15, 2022
cb791ac
pg dtype bugfix
alexbarghi-nv Dec 19, 2022
1af9495
fix mg dtype bug:
alexbarghi-nv Dec 19, 2022
5670bed
fix formatting issue
alexbarghi-nv Dec 19, 2022
6d6dfea
add ci test skip
alexbarghi-nv Dec 19, 2022
16fc189
fix renumbering test
alexbarghi-nv Dec 19, 2022
6bd78e6
fix tests
alexbarghi-nv Dec 19, 2022
99029b0
Merge branch 'branch-23.02' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Dec 19, 2022
8475792
remove debug code
alexbarghi-nv Dec 20, 2022
2679be0
drop persist from get_vertex_data, change dtype only if necessary
alexbarghi-nv Dec 20, 2022
6b1dbd8
propagate change to get_edge_data
alexbarghi-nv Dec 20, 2022
8734e31
add type check, explanatory comments
alexbarghi-nv Dec 20, 2022
88dac25
docstring fix
alexbarghi-nv Dec 20, 2022
bef36ba
drop old index
alexbarghi-nv Dec 20, 2022
8e7122d
drop old index
alexbarghi-nv Dec 20, 2022
b744a8a
remove type check
alexbarghi-nv Dec 20, 2022
61be6bc
remove dtype check
alexbarghi-nv Dec 20, 2022
ea5d49c
remove persist again
alexbarghi-nv Dec 20, 2022
1b7f68f
revert cmake changes as problem should be resolved
alexbarghi-nv Dec 20, 2022
4772a05
Merge branch 'branch-23.02' into cp-offsets
alexbarghi-nv Jan 3, 2023
aead9a6
update copyright year
alexbarghi-nv Jan 3, 2023
b528fff
update conftest
alexbarghi-nv Jan 5, 2023
8e679ab
clarify docstring
alexbarghi-nv Jan 5, 2023
eada08e
remove print statement
alexbarghi-nv Jan 5, 2023
81a8676
do merge
alexbarghi-nv Jan 5, 2023
5e4b054
remove additions from graph api
alexbarghi-nv Jan 5, 2023
0310f39
Merge branch 'branch-23.02' into cp-offsets
alexbarghi-nv Jan 5, 2023
d02fb4f
fix style
alexbarghi-nv Jan 5, 2023
e0b367d
merge
alexbarghi-nv Jan 5, 2023
f3d1d51
update cugraph store
alexbarghi-nv Jan 5, 2023
8cdae58
fix cugraph store
alexbarghi-nv Jan 5, 2023
cd07a11
fixes
alexbarghi-nv Jan 6, 2023
035498e
merge
alexbarghi-nv Jan 6, 2023
5b8a843
fix style
alexbarghi-nv Jan 6, 2023
11485b5
Merge branch 'branch-23.02' into cp-offsets
alexbarghi-nv Jan 12, 2023
177 changes: 150 additions & 27 deletions python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -14,13 +14,32 @@
from typing import Optional, Tuple, Any
from enum import Enum

import cupy

from dataclasses import dataclass
from collections import defaultdict
from itertools import chain
from functools import cached_property

import warnings

# numpy is always available
import numpy as np

# cuGraph or cuGraph-Service is required; each has its own version of
# import_optional and we need to select the correct one.
try:
from cugraph_service.client.remote_graph_utils import import_optional
except ModuleNotFoundError:
try:
from cugraph.utilities.utils import import_optional
except ModuleNotFoundError:
raise ModuleNotFoundError(
"cuGraph-PyG requires cuGraph or cuGraph-Service to be installed."
)

# FIXME drop cupy support and make torch the only backend (#2995)
cupy = import_optional("cupy")
torch = import_optional("torch")


class EdgeLayout(Enum):
COO = "coo"
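
The optional-import fallback in the hunk above can be illustrated with a small standalone sketch. This is not the actual `import_optional` from cuGraph or cuGraph-Service; `import_optional_sketch` and `MissingModule` are hypothetical names, and the behavior (return a placeholder that raises on first use) is an assumption about how such a helper might work.

```python
import importlib


class MissingModule:
    """Hypothetical placeholder for an optional dependency that is absent."""

    def __init__(self, name):
        self.name = name

    def __getattr__(self, attr):
        # Only called when normal lookup fails, i.e. on any real use
        # of the missing module (absent.array, absent.tensor, ...).
        raise RuntimeError(f"This feature requires the {self.name} package")


def import_optional_sketch(name):
    """Return the imported module, or a placeholder if it is not installed."""
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError:
        return MissingModule(name)


json_mod = import_optional_sketch("json")  # stdlib module, always present
absent = import_optional_sketch("not_a_real_module_xyz")  # placeholder object
```

With this shape, importing the store module succeeds even when only one of `cupy` or `torch` is installed, and the failure is deferred to the first call that actually needs the missing backend.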
@@ -74,7 +93,7 @@ def cast(cls, *args, **kwargs):
return cls(*args, **kwargs)


def EXPERIMENTAL__to_pyg(G, backend="torch"):
def EXPERIMENTAL__to_pyg(G, backend="torch", renumber_graph=None):
"""
Returns the PyG wrappers for the provided PropertyGraph or
MGPropertyGraph.
@@ -83,13 +102,20 @@ def EXPERIMENTAL__to_pyg(G, backend="torch"):
----------
G : PropertyGraph or MGPropertyGraph
The graph to produce PyG wrappers for.
renumber_graph: bool
Should usually be set to True. If True, the vertices and edges
in the provided property graph will be renumbered so that they
are contiguous by type. If the vertices and edges are already
contiguously renumbered by type, then this can be set to False.

Returns
-------
Tuple (CuGraphStore, CuGraphStore)
Wrappers for the provided property graph.
"""
store = EXPERIMENTAL__CuGraphStore(G, backend=backend)
store = EXPERIMENTAL__CuGraphStore(
G, backend=backend, renumber_graph=renumber_graph
)
return (store, store)


@@ -178,7 +204,7 @@ class EXPERIMENTAL__CuGraphStore:
Duck-typed version of PyG's GraphStore and FeatureStore.
"""

def __init__(self, G, reserved_keys=[], backend="torch"):
def __init__(self, G, backend="torch", renumber_graph=None):
"""
Constructs a new CuGraphStore from the provided
arguments.
@@ -188,27 +214,35 @@ def __init__(self, G, reserved_keys=[], backend="torch"):
G : PropertyGraph or MGPropertyGraph
The cuGraph property graph where the
data is being stored.
reserved_keys : Properties in the graph that are not used for
training (the 'x' attribute will ignore these properties).
backend : The backend that manages tensors (default = 'torch')
backend : ('torch', 'cupy')
The backend that manages tensors (default = 'torch')
Should usually be 'torch' ('torch', 'cupy' supported).
renumber_graph : bool
If True, will renumber vertices and edges to have contiguous
ids per type. If False, will not renumber vertices. If not
specified, will renumber and raise a warning.
"""

# TODO ensure all x properties are float32 type
# TODO ensure y is of long type
# FIXME ensure all x properties are float32 type
# FIXME ensure y is of long type
if None in G.edge_types:
raise ValueError("Unspecified edge types not allowed in PyG")

# FIXME drop the cupy backend and remove these checks (#2995)
if backend == "torch":
from torch.utils.dlpack import from_dlpack
from torch import int64 as vertex_dtype
from torch import float32 as property_dtype
from torch import searchsorted as searchsorted
from torch import concatenate as concatenate
from torch import arange as arange
elif backend == "cupy":
from cupy import from_dlpack
from cupy import int64 as vertex_dtype
from cupy import float32 as property_dtype
from cupy import searchsorted as searchsorted
from cupy import concatenate as concatenate
from cupy import arange as arange
else:
raise ValueError(f"Invalid backend {backend}.")

@@ -217,19 +251,21 @@ def __init__(self, G, reserved_keys=[], backend="torch"):
self.vertex_dtype = vertex_dtype
self.property_dtype = property_dtype
self.searchsorted = searchsorted
self.concatenate = concatenate
self.arange = arange

self.__graph = G
self.__subgraphs = {}

self.__reserved_keys = [
self.__graph.type_col_name,
self.__graph.vertex_col_name,
] + list(reserved_keys)

self._tensor_attr_cls = CuGraphTensorAttr
self._tensor_attr_dict = defaultdict(list)
self.__infer_x_and_y_tensors()

# Must be called after __infer_x_and_y_tensors to
# avoid adding the old vertex id as a property when
# users do not specify it.
self.__renumber_graph(renumber_graph)

self.__edge_types_to_attrs = {}
for edge_type in self.__graph.edge_types:
edges = self.__graph.get_edge_data(types=[edge_type])
@@ -269,6 +305,92 @@ def __init__(self, G, reserved_keys=[], backend="torch"):

self._edge_attr_cls = CuGraphEdgeAttr

def __renumber_graph(self, renumber_graph):
"""
Renumbers the vertices and edges in this store's property graph
and sets the vertex offsets.
If renumber_graph is False, then renumber_vertices_by_type()
and renumber_edges_by_type()
are not called and the offsets are inferred from vertex counts.

If renumber_graph is None, it defaults to True, warns the
user of this default behavior, and saves the current ids as
<vertex_col>_old and <edge_id_col>_old.

If renumber_graph is True, it calls renumber_vertices_by_type()
and renumber_edges_by_type(),
overwriting the current vertex and edge ids without saving them.
"""
self.__old_vertex_col_name = None
self.__old_edge_col_name = None

if renumber_graph is None:
renumber_graph = True
self.__old_vertex_col_name = f"{self.__graph.vertex_col_name}_old"
self.__old_edge_col_name = f"{self.__graph.edge_id_col_name}_old"
warnings.warn(
"renumber_graph not specified; renumbering by default "
f"and saving the original ids as {self.__old_vertex_col_name} "
f"and {self.__old_edge_col_name}"
)

if renumber_graph:
if self.is_remote and self.backend == "torch":
self.__vertex_type_offsets = self.__graph.renumber_vertices_by_type(
prev_id_column=self.__old_vertex_col_name,
backend="torch:cuda" if torch.has_cuda else "torch",
)
else:
self.__vertex_type_offsets = self.__graph.renumber_vertices_by_type(
prev_id_column=self.__old_vertex_col_name
)

# FIXME: https://github.com/rapidsai/cugraph/issues/3059
# Currently renumbering edges is required if renumbering vertices or else
# there is a dask partitioning issue.
self.__graph.renumber_edges_by_type(prev_id_column=self.__old_edge_col_name)

else:
self.__vertex_type_offsets = {}
self.__vertex_type_offsets["stop"] = [
self.__graph.get_num_vertices(vt) for vt in self.__graph.vertex_types
]
if self.__backend == "cupy":
self.__vertex_type_offsets["stop"] = cupy.array(
self.__vertex_type_offsets["stop"]
)
else:
self.__vertex_type_offsets["stop"] = torch.tensor(
self.__vertex_type_offsets["stop"]
)
if torch.has_cuda:
self.__vertex_type_offsets["stop"] = self.__vertex_type_offsets[
"stop"
].cuda()

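# "stop" holds per-type vertex counts here; convert them to
# inclusive [start, stop] id ranges using the running total.
cumsum = self.__vertex_type_offsets["stop"].cumsum(0)
self.__vertex_type_offsets["start"] = (
cumsum - self.__vertex_type_offsets["stop"]
)
self.__vertex_type_offsets["stop"] = cumsum - 1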
self.__vertex_type_offsets["type"] = np.array(self.__graph.vertex_types)
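
The three-way handling of `renumber_graph` (None, True, False) described in the docstring above can be sketched on its own, without a property graph. The function name `resolve_renumber` and the column names `vertex_id` and `edge_id` are hypothetical and used only for illustration.

```python
import warnings


def resolve_renumber(renumber_graph, vertex_col="vertex_id", edge_col="edge_id"):
    """Return (do_renumber, old_vertex_col, old_edge_col).

    None  -> renumber, warn, and keep the original ids under "<col>_old".
    True  -> renumber and overwrite the ids without saving them.
    False -> do not renumber.
    """
    if renumber_graph is None:
        old_vertex_col = f"{vertex_col}_old"
        old_edge_col = f"{edge_col}_old"
        warnings.warn(
            "renumber_graph not specified; renumbering by default "
            f"and saving the original ids as {old_vertex_col} and {old_edge_col}"
        )
        return True, old_vertex_col, old_edge_col
    return bool(renumber_graph), None, None
```

Using None as the "not specified" sentinel is what lets the store tell an explicit `renumber_graph=True` apart from the default, so only callers who never made a choice get the warning and the extra `_old` columns.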

@property
def _old_vertex_col_name(self):
"""
Returns the name of the new property in the wrapped property graph where
the original vertex ids were stored, if this store did its own renumbering.
"""
return self.__old_vertex_col_name

@property
def _old_edge_col_name(self):
"""
Returns the name of the new property in the wrapped property graph where
the original edge ids were stored, if this store did its own renumbering.
"""
return self.__old_edge_col_name

@property
def _edge_types_to_attrs(self):
return dict(self.__edge_types_to_attrs)
@@ -297,18 +419,18 @@ def _is_delayed(self):
return self.is_multi_gpu and not self.is_remote

def get_vertex_index(self, vtypes):
# TODO force the graph to use offsets and
# return these values based on offsets

if isinstance(vtypes, str):
vtypes = [vtypes]

ix = self.__graph.get_vertex_data(
types=vtypes, columns=[self.__graph.type_col_name]
)[self.__graph.vertex_col_name]

if self._is_delayed:
ix = ix.compute()
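# FIXME always use torch, drop cupy (#2995)
# Start from an empty id tensor/array and append each type's range.
if self.__backend == "torch":
ix = torch.tensor([], dtype=self.vertex_dtype)
else:
ix = cupy.array([], dtype=self.vertex_dtype)
for vtype in vtypes:
start = self.__vertex_type_offsets["start"][vtype]
stop = self.__vertex_type_offsets["stop"][vtype]
# concatenate takes a sequence of arrays; stop is inclusive.
ix = self.concatenate([ix, self.arange(start, stop + 1, 1)])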

return self.from_dlpack(ix.to_dlpack())
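
A worked example may help when reviewing the offset logic. The sketch below uses plain Python lists rather than torch or cupy, and the helper names `offsets_from_counts` and `vertex_index` are hypothetical. It assumes vertex ids are contiguous by type, so each type's inclusive range is start = running total minus count, stop = running total minus 1.

```python
from itertools import accumulate


def offsets_from_counts(vertex_types, counts):
    """Map each vertex type to its inclusive (start, stop) id range."""
    ends = list(accumulate(counts))  # exclusive end of each type's range
    return {
        vtype: (end - count, end - 1)
        for vtype, count, end in zip(vertex_types, counts, ends)
    }


def vertex_index(offsets, vtypes):
    """Return all vertex ids for the requested type or list of types."""
    if isinstance(vtypes, str):
        vtypes = [vtypes]
    index = []
    for vtype in vtypes:
        start, stop = offsets[vtype]
        index.extend(range(start, stop + 1))  # stop is inclusive
    return index


# Two vertex types with 3 and 5 vertices respectively.
offsets = offsets_from_counts(["author", "paper"], [3, 5])
```

Here `offsets` is `{"author": (0, 2), "paper": (3, 7)}`, and `vertex_index(offsets, "paper")` yields ids 3 through 7. Deriving the index from offsets avoids a `get_vertex_data` query (and a `compute()` in the multi-GPU case) each time the index is requested.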

@@ -461,9 +583,10 @@ def _subgraph(self, edge_types):
edge_types = tuple(sorted(edge_types))

if edge_types not in self.__subgraphs:
query = f'(_TYPE_=="{edge_types[0]}")'
TCN = self.__graph.type_col_name
query = f'({TCN}=="{edge_types[0]}")'
for t in edge_types[1:]:
query += f' | (_TYPE_=="{t}")'
query += f' | ({TCN}=="{t}")'
selection = self.__graph.select_edges(query)

# FIXME enforce int type
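
The edge-type query built in `_subgraph` can be reproduced in isolation to check the string it produces. `build_type_query` is a hypothetical helper, and the column name passed in the test is only an example value; the real name comes from the graph's `type_col_name`.

```python
def build_type_query(type_col_name, edge_types):
    """Build an OR-ed equality filter over the edge type column."""
    return " | ".join(f'({type_col_name}=="{t}")' for t in edge_types)
```

For two types this gives the same string as the loop in the diff. One difference in this sketch: an empty `edge_types` returns an empty string, whereas indexing `edge_types[0]` as the diff does would raise `IndexError`.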
5 changes: 2 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/loader/dispatch.py
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -20,8 +20,7 @@
from cugraph.utilities.utils import import_optional
except ModuleNotFoundError:
raise ModuleNotFoundError(
"cuGraph extensions for PyG require cuGraph"
"or cuGraph-Service to be installed."
"cuGraph-PyG requires cuGraph or cuGraph-Service to be installed."
)

_transform_to_backend_dtype_1d = import_optional(
77 changes: 77 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/conftest.py
@@ -0,0 +1,77 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import pytest

from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from dask_cuda.initialize import initialize

from cugraph.dask.comms import comms as Comms
from cugraph.dask.common.mg_utils import get_visible_devices


# module-wide fixtures


# Spoof the gpubenchmark fixture if it's not available so that asvdb and
# rapids-pytest-benchmark do not need to be installed to run tests.
if "gpubenchmark" not in globals():

def benchmark_func(func, *args, **kwargs):
return func(*args, **kwargs)

@pytest.fixture
def gpubenchmark():
return benchmark_func


@pytest.fixture(scope="module")
def dask_client():
dask_scheduler_file = os.environ.get("SCHEDULER_FILE")
cluster = None
client = None
tempdir_object = None

if dask_scheduler_file:
# Env var UCX_MAX_RNDV_RAILS=1 must be set too.
initialize(
enable_tcp_over_ucx=True,
enable_nvlink=True,
enable_infiniband=True,
enable_rdmacm=True,
# net_devices="mlx5_0:1",
)
client = Client(scheduler_file=dask_scheduler_file)
print("\ndask_client fixture: client created using " f"{dask_scheduler_file}")
else:
# The tempdir created by tempdir_object should be cleaned up once
# tempdir_object goes out-of-scope and is deleted.
tempdir_object = tempfile.TemporaryDirectory()
cluster = LocalCUDACluster(local_directory=tempdir_object.name)
client = Client(cluster)
client.wait_for_workers(len(get_visible_devices()))
print("\ndask_client fixture: client created using LocalCUDACluster")

Comms.initialize(p2p=True)

yield client

Comms.destroy()
client.close()
if cluster:
cluster.close()
print("\ndask_client fixture: client.close() called")