Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds arbitrary server extension support to cugraph-service #2850

Merged
20 changes: 20 additions & 0 deletions python/cugraph_service/cugraph_service_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Constants shared by the client and the server.
# (The server package depends on the client package, which lets server code
# reuse client code, utilities, defaults, etc.)
supported_extension_return_dtypes = [
    "NoneType",
    "int8",
    "int16",
    "int32",
    "int64",
    "float16",
    "float32",
    "float64",
]
# Two-way lookup between dtype name strings and small integer codes, used when
# dtype meta-data is exchanged between client and server. The integer -> name
# entries are inserted first, followed by the name -> integer entries.
extension_return_dtype_map = {
    code: name for code, name in enumerate(supported_extension_return_dtypes)
}
extension_return_dtype_map.update(
    {name: code for code, name in enumerate(supported_extension_return_dtypes)}
)

from cugraph_service_client.client import CugraphServiceClient
from cugraph_service_client.remote_graph import RemoteGraph
from cugraph_service_client.remote_graph import RemotePropertyGraph
249 changes: 223 additions & 26 deletions python/cugraph_service/cugraph_service_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import cupy as cp

from cugraph_service_client import defaults
from cugraph_service_client import extension_return_dtype_map
from cugraph_service_client.remote_graph import RemotePropertyGraph
from cugraph_service_client.types import (
ValueWrapper,
Expand Down Expand Up @@ -250,27 +251,64 @@ def load_graph_creation_extensions(self, extension_dir_path):

Returns
-------
num_files_read : int
Number of extension files read in the extension_dir_path directory.
extension_modnames : list
List of the module names loaded. These can be used in calls to
unload_extension_module()

Examples
--------
>>> from cugraph_service_client import CugraphServiceClient
>>> client = CugraphServiceClient()
>>> num_files_read = client.load_graph_creation_extensions(
>>> extension_modnames = client.load_graph_creation_extensions(
... "/some/server/side/directory")
>>>
"""
return self.__client.load_graph_creation_extensions(extension_dir_path)

@__server_connection
def unload_graph_creation_extensions(self):
def load_extensions(self, extension_dir_or_mod_path):
"""
Removes all extensions for graph creation previously loaded.
Loads the extensions present in the directory (path on disk), or module or
package path (as used in an import statement) specified by
extension_dir_or_mod_path.

Parameters
----------
None
extension_dir_or_mod_path : string
Path to the directory containing the extension files (.py source
files), or an importable module or package path (eg. my.package or
my.package.module). If a directory is specified it must be readable
by the server, and if a module or package path is specified it must
be importable by the server (ie. present in the sys.path of the
running server).

Returns
-------
extension_modnames : list
List of the module names loaded as paths to files on disk. These can
be used in calls to unload_extension_module()

Examples
--------
>>> from cugraph_service_client import CugraphServiceClient
>>> client = CugraphServiceClient()
>>> extension_modnames = client.load_extensions(
... "/some/server/side/directory")
>>> more_extension_modnames = client.load_extensions(
... "my_project.extensions.etl")
"""
return self.__client.load_extensions(extension_dir_or_mod_path)

@__server_connection
def unload_extension_module(self, modname):
"""
Removes all extensions contained in the modname module.

Parameters
----------
modname : string
Name of the module to be unloaded. All extension functions contained in
modname will no longer be callable.

Returns
-------
Expand All @@ -280,10 +318,12 @@ def unload_graph_creation_extensions(self):
--------
>>> from cugraph_service_client import CugraphServiceClient
>>> client = CugraphServiceClient()
>>> client.unload_graph_creation_extensions()
>>> ext_mod_names = client.load_graph_creation_extensions(
... "/some/server/side/directory")
>>> client.unload_extension_module(ext_mod_names[0])
>>>
"""
return self.__client.unload_graph_creation_extensions()
return self.__client.unload_extension_module(modname)

@__server_connection
def call_graph_creation_extension(self, func_name, *func_args, **func_kwargs):
Expand Down Expand Up @@ -336,6 +376,84 @@ def call_graph_creation_extension(self, func_name, *func_args, **func_kwargs):
func_name, func_args_repr, func_kwargs_repr
)

@__server_connection
def call_extension(
    self,
    func_name,
    *func_args,
    result_device=None,
    **func_kwargs,
):
    """
    Run an extension function on the server and return its result.

    The extension must have been loaded on the server beforehand by a call
    to load_extensions().

    Parameters
    ----------
    func_name : string
        Name of the server-side extension function to call.

    *func_args : string, int, list, dictionary (optional)
        Positional args to pass to func_name. They are sent to the server
        as their repr() string and, per the existing API notes, restored
        there with eval(), so only objects that survive a repr()/eval()
        round trip are supported.

    result_device : int, default is None
        If specified, must be the integer ID of a GPU device to have the
        server transfer results to as one or more cupy ndarrays.

    **func_kwargs : string, int, list, dictionary (optional)
        Keyword args to pass to func_name. Serialized the same way as
        func_args, with the same restriction.

    Returns
    -------
    result : python int, float, string, list, or device array(s)
        When result_device is None, the RPC return value converted to a
        python object. Otherwise, the array (or list of arrays) received on
        result_device.

    Examples
    --------
    >>> from cugraph_service_client import CugraphServiceClient
    >>> client = CugraphServiceClient()
    >>> # Load the extension file containing "my_serverside_function()"
    >>> client.load_extensions("/some/server/side/dir")
    >>> result = client.call_extension(
    ...     "my_serverside_function", 33, 22, "some_string")
    >>>
    """
    args_repr = repr(func_args)
    kwargs_repr = repr(func_kwargs)

    if result_device is None:
        # Plain RPC path: the result comes back in the RPC return value.
        # NOTE(review): the Thrift IDL in this change names the last two
        # parameters result_host/result_port; confirm these keyword names
        # are what the generated client expects.
        rpc_value = self.__client.call_extension(
            func_name,
            args_repr,
            kwargs_repr,
            client_host=None,
            client_result_port=None,
        )
        # Convert the structure returned from the RPC call to a python type
        # FIXME: ValueWrapper ctor and get_py_obj are recursive and could be
        # slow, especially if Value is a list. Consider returning the Value
        # obj as-is.
        return ValueWrapper(rpc_value).get_py_obj()

    # Device path: the helper receives the result array(s) on result_device
    # and returns them.
    return asyncio.run(
        self.__call_extension_to_device(
            func_name, args_repr, kwargs_repr, result_device
        )
    )

###########################################################################
# Graph management
@__server_connection
Expand Down Expand Up @@ -998,16 +1116,6 @@ def batched_ego_graphs(self, seeds, radius=1, graph_id=defaults.graph_id):
seeds, radius, graph_id
)

# FIXME: ensure dtypes are correct for values returned from
# cugraph.batched_ego_graphs() in cugraph_handler.py
# return (numpy.frombuffer(batched_ego_graphs_result.src_verts,
# dtype="int32"),
# numpy.frombuffer(batched_ego_graphs_result.dst_verts,
# dtype="int32"),
# numpy.frombuffer(batched_ego_graphs_result.edge_weights,
# dtype="float64"),
# numpy.frombuffer(batched_ego_graphs_result.seeds_offsets,
# dtype="int64"))
return (
batched_ego_graphs_result.src_verts,
batched_ego_graphs_result.dst_verts,
Expand Down Expand Up @@ -1067,15 +1175,15 @@ def uniform_neighbor_sample(
Samples the graph and returns a UniformNeighborSampleResult instance.

Parameters:
start_list: list[int]
start_list : list[int]

fanout_vals: list[int]
fanout_vals : list[int]

with_replacement: bool
with_replacement : bool

graph_id: int, default is defaults.graph_id
graph_id : int, default is defaults.graph_id

result_device: int, default is None
result_device : int, default is None

Returns
-------
Expand Down Expand Up @@ -1215,7 +1323,16 @@ async def receiver(endpoint):

listener = ucp.create_listener(receiver, self.results_port)

uns_thread = threading.Thread(
# Use an excepthook to store an exception on the thread object if one is
# raised in the thread.
def excepthook(exc):
if exc.thread is not None:
exc.thread.exception = exc.exc_type(exc.exc_value)

orig_excepthook = threading.excepthook
threading.excepthook = excepthook

thread = threading.Thread(
target=self.__client.uniform_neighbor_sample,
args=(
start_list,
Expand All @@ -1226,14 +1343,94 @@ async def receiver(endpoint):
self.results_port,
),
)
uns_thread.start()
thread.start()

# Poll the listener and the state of the thread. Close the listener if
# the thread died and raise the stored exception.
while not listener.closed():
await asyncio.sleep(0.05)
if not thread.is_alive():
listener.close()
threading.excepthook = orig_excepthook
if hasattr(thread, "exception"):
raise thread.exception

uns_thread.join()
thread.join()
return result_obj

async def __call_extension_to_device(
    self, func_name, func_args_repr, func_kwargs_repr, result_device
):
    """
    Run the server-side extension func_name with the args/kwargs and have the
    result sent directly to the device specified by result_device.

    Parameters
    ----------
    func_name : string
        Name of the server-side extension function to call.

    func_args_repr : string
        repr() string of the positional args to pass to func_name.

    func_kwargs_repr : string
        repr() string of the keyword args to pass to func_name.

    result_device : int
        ID of the GPU device the received arrays are allocated on.

    Returns
    -------
    result : array or list of arrays
        The arrays received from the server, each viewed as the dtype named
        in the leading meta-data array. A single received array is returned
        directly rather than as a list of length 1.

    Raises
    ------
    Exception
        Any exception raised by the RPC call in the worker thread is
        re-raised here as exc_type(exc_value).
    """
    # FIXME: there's probably a better way to do this, eg. create a class
    # containing both allocator and receiver that maintains results, devices,
    # etc. that's callable from the listener
    result = []

    # FIXME: check for valid device
    allocator = DeviceArrayAllocator(result_device)

    async def receiver(endpoint):
        # Format of data sent is assumed to be:
        # 1) a single array of length n describing the dtypes for the n
        #    arrays that follow
        # 2) n arrays
        with cp.cuda.Device(result_device):
            # First get the array describing the data
            # FIXME: meta_data doesn't need to be a cupy array
            dtype_meta_data = await endpoint.recv_obj(allocator=allocator)
            for dtype_enum in [int(i) for i in dtype_meta_data]:
                # FIXME: safe to assume dtype_enum will always be valid?
                dtype = extension_return_dtype_map[dtype_enum]
                a = await endpoint.recv_obj(allocator=allocator)
                result.append(a.view(dtype))

        await endpoint.close()
        listener.close()

    listener = ucp.create_listener(receiver, self.results_port)

    # Use an excepthook to store an exception on the thread object if one is
    # raised in the thread.
    def excepthook(exc):
        if exc.thread is not None:
            exc.thread.exception = exc.exc_type(exc.exc_value)

    orig_excepthook = threading.excepthook
    threading.excepthook = excepthook
    try:
        thread = threading.Thread(
            target=self.__client.call_extension,
            args=(
                func_name,
                func_args_repr,
                func_kwargs_repr,
                self.host,
                self.results_port,
            ),
        )
        thread.start()

        # Poll the listener and the state of the thread. Close the listener
        # if the thread died and raise the stored exception.
        while not listener.closed():
            await asyncio.sleep(0.05)
            if not thread.is_alive():
                listener.close()
                if hasattr(thread, "exception"):
                    raise thread.exception

        thread.join()
    finally:
        # threading.excepthook is process-wide, so always put the previous
        # hook back. Previously it was only restored when the poll loop saw
        # the thread die, leaving it replaced whenever the receiver closed
        # the listener first.
        threading.excepthook = orig_excepthook

    # The receiver may close the listener before the poll loop notices that
    # the RPC thread failed; do not silently drop that failure.
    if hasattr(thread, "exception"):
        raise thread.exception

    # special case, assume a list of len 1 should not be a list
    if len(result) == 1:
        result = result[0]
    return result

@staticmethod
def __get_vertex_edge_id_obj(id_or_ids):
# FIXME: do not assume all values are int32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
2:i64 int64_value
3:string string_value
4:bool bool_value
5:double double_value
6:list<Value> list_value
}

service CugraphService {
Expand All @@ -87,16 +89,25 @@

map<string, Value> get_server_info() throws (1:CugraphServiceError e),

i32 load_graph_creation_extensions(1:string extension_dir_path
) throws (1:CugraphServiceError e),
list<string> load_graph_creation_extensions(1:string extension_dir_path
) throws (1:CugraphServiceError e),

void unload_graph_creation_extensions(),
list<string> load_extensions(1:string extension_dir_path
) throws (1:CugraphServiceError e),

void unload_extension_module(1:string modname) throws (1:CugraphServiceError e),

i32 call_graph_creation_extension(1:string func_name,
2:string func_args_repr,
3:string func_kwargs_repr
) throws (1:CugraphServiceError e),

Value call_extension(1:string func_name,
2:string func_args_repr,
3:string func_kwargs_repr
4:string result_host,
5:i16 result_port
) throws (1:CugraphServiceError e),

##############################################################################
# Graph management
Expand Down
Loading