diff --git a/python/gaas_client/__init__.py b/python/gaas_client/__init__.py index 3068e07..68591c7 100644 --- a/python/gaas_client/__init__.py +++ b/python/gaas_client/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .client import GaasClient +from gaas_client.client import GaasClient diff --git a/python/gaas_client/client.py b/python/gaas_client/client.py index 631f23d..4f7af57 100644 --- a/python/gaas_client/client.py +++ b/python/gaas_client/client.py @@ -14,8 +14,8 @@ from functools import wraps -from . import defaults -from .gaas_thrift import create_client +from gaas_client import defaults +from gaas_client.gaas_thrift import create_client class GaasClient: @@ -76,12 +76,14 @@ def wrapped_method(self, *args, **kwargs): return ret_val return wrapped_method - def open(self): + def open(self, call_timeout=90000): """ Opens a connection to the server at self.host/self.port if one is not already established. close() must be called in order to allow other connections from other clients to be made. + This call does nothing if a connection to the server is already open. + Note: all APIs that access the server will call this method automatically, followed automatically by a call to close(), so calling this method should not be necessary. close() is not automatically called @@ -89,7 +91,9 @@ def open(self): Parameters ---------- - None + call_timeout : int (default is 90000) + Time in millisecods that calls to the server using this open + connection must return by. Returns ------- @@ -103,12 +107,15 @@ def open(self): >>> # clients cannot connect until a client API call completes or >>> # close() is manually called. >>> client.open() + """ if self.__client is None: - self.__client = create_client(self.host, self.port) + self.__client = create_client(self.host, self.port, + call_timeout=call_timeout) def close(self): - """Closes a connection to the server if one has been established, allowing + """ + Closes a connection to the server if one has been established, allowing other clients to access the server. This method is called automatically for all APIs that access the server if self.hold_open is False. @@ -138,6 +145,131 @@ def close(self): self.__client.close() self.__client = None + ############################################################################ + # Environment management + @__server_connection + def uptime(self): + """ + Return the server uptime in seconds. This is often used as a "ping". + + Parameters + ---------- + None + + Returns + ------- + uptime : int + The time in seconds the server has been running. + + Examples + -------- + >>> from gaas_client import GaasClient + >>> client = GaasClient() + >>> client.uptime() + >>> 32 + """ + return self.__client.uptime() + + @__server_connection + def load_graph_creation_extensions(self, extension_dir_path): + """ + Loads the extensions for graph creation present in the directory + specified by extension_dir_path. + + Parameters + ---------- + extension_dir_path : string + Path to the directory containing the extension files (.py source + files). This directory must be readable by the server. + + Returns + ------- + num_files_read : int + Number of extension files read in the extension_dir_path directory. + + Examples + -------- + >>> from gaas_client import GaasClient + >>> client = GaasClient() + >>> num_files_read = client.load_graph_creation_extensions( + ... "/some/server/side/directory") + >>> + """ + return self.__client.load_graph_creation_extensions(extension_dir_path) + + @__server_connection + def unload_graph_creation_extensions(self): + """ + Removes all extensions for graph creation previously loaded. + + Parameters + ---------- + None + + Returns + ------- + None + + Examples + -------- + >>> from gaas_client import GaasClient + >>> client = GaasClient() + >>> client.unload_graph_creation_extensions() + >>> + """ + return self.__client.unload_graph_creation_extensions() + + @__server_connection + def call_graph_creation_extension(self, func_name, + *func_args, **func_kwargs): + """ + Calls a graph creation extension on the server that was previously + loaded by a prior call to load_graph_creation_extensions(), then returns + the graph ID of the graph created by the extension. + + Parameters + ---------- + func_name : string + The name of the server-side extension function loaded by a prior + call to load_graph_creation_extensions(). All graph creation + extension functions are expected to return a new graph. + + *func_args : string, int, list, dictionary (optional) + The positional args to pass to func_name. Note that func_args are + converted to their string representation using repr() on the client, + then restored to python objects on the server using eval(), and + therefore only objects that can be restored server-side with eval() + are supported. + + **func_kwargs : string, int, list, dictionary + The keyword args to pass to func_name. Note that func_kwargs are + converted to their string representation using repr() on the client, + then restored to python objects on the server using eval(), and + therefore only objects that can be restored server-side with eval() + are supported. + + Returns + ------- + graph_id : int + unique graph ID + + Examples + -------- + >>> from gaas_client import GaasClient + >>> client = GaasClient() + >>> # Load the extension file containing "my_complex_create_graph()" + >>> client.load_graph_creation_extensions("/some/server/side/directory") + >>> new_graph_id = client.call_graph_creation_extension( + ... "my_complex_create_graph", + ... "/path/to/csv/on/server/graph.csv", + ... clean_data=True) + >>> + """ + func_args_repr = repr(func_args) + func_kwargs_repr = repr(func_kwargs) + return self.__client.call_graph_creation_extension( + func_name, func_args_repr, func_kwargs_repr) + ############################################################################ # Graph management @__server_connection @@ -460,8 +592,9 @@ def extract_subgraph(self, Examples -------- >>> - """ + # FIXME: finish docstring above + # FIXME: convert defaults to type needed by the Thrift API. These will # be changing to different types. create_using = create_using or "" @@ -504,6 +637,8 @@ def node2vec(self, start_vertices, max_depth, graph_id=defaults.graph_id): -------- >>> """ + # FIXME: finish docstring above + # start_vertices must be a list (cannot just be an iterable), and assume # return value is tuple of python lists on host. if not isinstance(start_vertices, list): diff --git a/python/gaas_client/defaults.py b/python/gaas_client/defaults.py index fe8e6bf..e84d52f 100644 --- a/python/gaas_client/defaults.py +++ b/python/gaas_client/defaults.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -host = "127.0.0.1" +host = "localhost" port = 9090 graph_id = 0 diff --git a/python/gaas_client/exceptions.py b/python/gaas_client/exceptions.py index a5906f4..12bad4f 100644 --- a/python/gaas_client/exceptions.py +++ b/python/gaas_client/exceptions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .gaas_thrift import spec +from gaas_client.gaas_thrift import spec +# FIXME: add more fine-grained exceptions! GaasError = spec.GaasError diff --git a/python/gaas_client/gaas_thrift.py b/python/gaas_client/gaas_thrift.py index 9625793..ba9c25c 100644 --- a/python/gaas_client/gaas_thrift.py +++ b/python/gaas_client/gaas_thrift.py @@ -17,6 +17,7 @@ import thriftpy2 from thriftpy2.rpc import make_server, make_client + # This is the Thrift input file as a string rather than a separate file. This # allows the Thrift input to be contained within the module that's responsible # for all Thrift-specific details rather than a separate .thrift file. @@ -24,8 +25,8 @@ # thriftpy2 (https://github.com/Thriftpy/thriftpy2) is being used here instead # of Apache Thrift since it offers an easier-to-use API exclusively for Python # which is still compatible with servers/cleints using Apache Thrift (Apache -# Thrift can be used from a variety of different languages) and offers roughly -# the same performance. +# Thrift can be used from a variety of different languages) while offering +# approximately the same performance. # # See the Apache Thrift tutorial for Python for examples: # https://thrift.apache.org/tutorial/py.html @@ -43,6 +44,8 @@ service GaasService { + i32 uptime() + i32 create_graph() throws(1:GaasError e), void delete_graph(1:i32 graph_id) throws (1:GaasError e), @@ -84,6 +87,16 @@ 5:bool allow_multi_edges, 6:i32 graph_id ) throws (1:GaasError e), + + i32 load_graph_creation_extensions(1:string extension_dir_path + ) throws (1:GaasError e), + + void unload_graph_creation_extensions(), + + i32 call_graph_creation_extension(1:string func_name, + 2:string func_args_repr, + 3:string func_kwargs_repr + ) throws (1:GaasError e), } """ @@ -108,9 +121,30 @@ def create_server(handler, host, port): """ return make_server(spec.GaasService, handler, host, port) -def create_client(host, port): + +def create_client(host, port, call_timeout=90000): """ Return a client object that will make calls on a server listening on host/port. + + The call_timeout value defaults to 90 seconds, and is used for setting the + timeout for server API calls when using the client created here - if a call + does not return in call_timeout milliseconds, an exception is raised. """ - return make_client(spec.GaasService, host=host, port=port) + try: + return make_client(spec.GaasService, host=host, port=port, + timeout=call_timeout) + except thriftpy2.transport.TTransportException: + # Rasie a GaaS exception in order to completely encapsulate all Thrift + # details in this module. If this was not done, callers of this function + # would have to import thriftpy2 in order to catch the + # TTransportException, which then leaks thriftpy2. + # + # NOTE: normally the GaasError exception is imported from the + # gaas_client.exceptions module, but since + # gaas_client.exceptions.GaasError is actually defined from the spec in + # this module, just use it directly from spec. + # + # FIXME: this exception could use more detail + raise spec.GaasError("could not create a client session with a " + "GaaS server") diff --git a/python/gaas_client/types.py b/python/gaas_client/types.py index 826c73e..a2d8203 100644 --- a/python/gaas_client/types.py +++ b/python/gaas_client/types.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .gaas_thrift import spec +from gaas_client.gaas_thrift import spec Node2vecResult = spec.Node2vecResult diff --git a/python/gaas_server/gaas_handler.py b/python/gaas_server/gaas_handler.py new file mode 100644 index 0000000..dc792dd --- /dev/null +++ b/python/gaas_server/gaas_handler.py @@ -0,0 +1,299 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import importlib +import time +import traceback + +import cudf +import cugraph +from cugraph.experimental import PropertyGraph + +from gaas_client import defaults +from gaas_client.exceptions import GaasError +from gaas_client.types import Node2vecResult + + +class GaasHandler: + """ + Class which handles RPC requests for a GaasService. + """ + def __init__(self): + self.__next_graph_id = defaults.graph_id + 1 + self.__graph_objs = {} + self.__graph_creation_extensions = {} + self.__start_time = int(time.time()) + + ############################################################################ + # Environment management + def uptime(self): + """ + Return the server uptime in seconds. This is often used as a "ping". + """ + return int(time.time()) - self.__start_time + + def load_graph_creation_extensions(self, extension_dir_path): + """ + Loads ("imports") all modules matching the pattern *_extension.py in the + directory specified by extension_dir_path. + + The modules are searched and their functions are called (if a match is + found) when call_graph_creation_extension() is called. + """ + extension_dir = Path(extension_dir_path) + + if (not extension_dir.exists()) or (not extension_dir.is_dir()): + raise GaasError(f"bad directory: {extension_dir}") + + num_files_read = 0 + + for ext_file in extension_dir.glob("*_extension.py"): + module_name = ext_file.stem + spec = importlib.util.spec_from_file_location(module_name, ext_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self.__graph_creation_extensions[module_name] = module + num_files_read += 1 + + return num_files_read + + def unload_graph_creation_extensions(self): + """ + Removes all graph creation extensions. + """ + self.__graph_creation_extensions.clear() + + def call_graph_creation_extension(self, func_name, + func_args_repr, func_kwargs_repr): + """ + Calls the graph creation extension function func_name and passes it the + eval'd func_args_repr and func_kwargs_repr objects. + + The arg/kwarg reprs are eval'd prior to calling in order to pass actual + python objects to func_name (this is needed to allow arbitrary arg + objects to be serialized as part of the RPC call from the + client). + + func_name cannot be a private name (name starting with __). + + All loaded extension modules are checked when searching for func_name, + and the first extension module that contains it will have its function + called. + """ + if not(func_name.startswith("__")): + for module in self.__graph_creation_extensions.values(): + # Ignore private functions + func = getattr(module, func_name, None) + if func is not None: + func_args = eval(func_args_repr) + func_kwargs = eval(func_kwargs_repr) + try: + graph_obj = func(*func_args, **func_kwargs) + except Exception: + # FIXME: raise a more detailed error + raise GaasError(f"error running {func_name} : " + f"{traceback.format_exc()}") + return self.__add_graph(graph_obj) + + raise GaasError(f"{func_name} is not a graph creation extension") + + ############################################################################ + # Graph management + def create_graph(self): + """ + Create a new graph associated with a new unique graph ID, return the new + graph ID. + """ + pG = PropertyGraph() + return self.__add_graph(pG) + + def delete_graph(self, graph_id): + """ + Remove the graph identified by graph_id from the server. + """ + if self.__graph_objs.pop(graph_id, None) is None: + raise GaasError(f"invalid graph_id {graph_id}") + + def get_graph_ids(self): + """ + Returns a list of the graph IDs currently in use. + """ + return list(self.__graph_objs.keys()) + + def load_csv_as_vertex_data(self, + csv_file_name, + delimiter, + dtypes, + header, + vertex_col_name, + type_name, + property_columns, + graph_id + ): + """ + Given a CSV csv_file_name present on the server's file system, read it + and apply it as edge data to the graph specified by graph_id, or the + default graph if not specified. + """ + pG = self._get_graph(graph_id) + if header == -1: + header = "infer" + elif header == -2: + header = None + # FIXME: error check that file exists + # FIXME: error check that edgelist was read correctly + gdf = cudf.read_csv(csv_file_name, + delimiter=delimiter, + dtype=dtypes, + header=header) + pG.add_vertex_data(gdf, + type_name=type_name, + vertex_col_name=vertex_col_name, + property_columns=property_columns) + + def load_csv_as_edge_data(self, + csv_file_name, + delimiter, + dtypes, + header, + vertex_col_names, + type_name, + property_columns, + graph_id + ): + """ + Given a CSV csv_file_name present on the server's file system, read it + and apply it as vertex data to the graph specified by graph_id, or the + default graph if not specified. + """ + pG = self._get_graph(graph_id) + # FIXME: error check that file exists + # FIXME: error check that edgelist read correctly + if header == -1: + header = "infer" + elif header == -2: + header = None + gdf = cudf.read_csv(csv_file_name, + delimiter=delimiter, + dtype=dtypes, + header=header) + pG.add_edge_data(gdf, + type_name=type_name, + vertex_col_names=vertex_col_names, + property_columns=property_columns) + + def get_num_edges(self, graph_id): + """ + Return the number of edges for the graph specified by graph_id. + """ + pG = self._get_graph(graph_id) + # FIXME: ensure non-PropertyGraphs that compute num_edges differently + # work too. + return pG.num_edges + + def extract_subgraph(self, + create_using, + selection, + edge_weight_property, + default_edge_weight, + allow_multi_edges, + graph_id + ): + """ + Extract a subgraph, return a new graph ID + """ + pG = self._get_graph(graph_id) + if not(isinstance(pG, PropertyGraph)): + raise GaasError("extract_subgraph() can only be called on a graph " + "with properties.") + # Convert defaults needed for the Thrift API into defaults used by + # PropertyGraph.extract_subgraph() + create_using = create_using or cugraph.Graph + selection = selection or None + edge_weight_property = edge_weight_property or None + + G = pG.extract_subgraph(create_using, + selection, + edge_weight_property, + default_edge_weight, + allow_multi_edges) + + return self.__add_graph(G) + + ############################################################################ + # Algos + def node2vec(self, start_vertices, max_depth, graph_id): + """ + """ + # FIXME: finish docstring above + # FIXME: exception handling + G = self._get_graph(graph_id) + if isinstance(G, PropertyGraph): + raise GaasError("node2vec() cannot operate directly on a graph with" + " properties, call extract_subgraph() then call " + "node2vec() on the extracted subgraph instead.") + + # FIXME: this should not be needed, need to update cugraph.node2vec to + # also accept a list + start_vertices = cudf.Series(start_vertices, dtype="int32") + + (paths, weights, path_sizes) = \ + cugraph.node2vec(G, start_vertices, max_depth) + + node2vec_result = Node2vecResult( + vertex_paths = paths.to_arrow().to_pylist(), + edge_weights = weights.to_arrow().to_pylist(), + path_sizes = path_sizes.to_arrow().to_pylist() + ) + return node2vec_result + + def pagerank(self, graph_id): + """ + """ + raise NotImplementedError + + ############################################################################ + # "Protected" interface - used for both implementation and test/debug. Will + # not be exposed to a GaaS client. + def _get_graph(self, graph_id): + """ + Return the cuGraph Graph object (likely a PropertyGraph) associated with + graph_id. + + If the graph_id is the default graph ID and the default graph has not + been created, then instantiate a new PropertyGraph as the default graph + and return it. + """ + pG = self.__graph_objs.get(graph_id) + if pG is None: + # Always create the default graph if it does not exist + if graph_id == defaults.graph_id: + pG = PropertyGraph() + self.__graph_objs[graph_id] = pG + else: + raise GaasError(f"invalid graph_id {graph_id}") + return pG + + ############################################################################ + # Private + def __add_graph(self, G): + """ + Create a new graph ID for G and add G to the internal mapping of + graph ID:graph instance. + """ + gid = self.__next_graph_id + self.__graph_objs[gid] = G + self.__next_graph_id += 1 + return gid diff --git a/python/gaas_server/server.py b/python/gaas_server/server.py index 78abd62..c5c4975 100644 --- a/python/gaas_server/server.py +++ b/python/gaas_server/server.py @@ -12,197 +12,54 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cudf -import cugraph -from cugraph.experimental import PropertyGraph +import argparse +from pathlib import Path from gaas_client import defaults -from gaas_client.exceptions import GaasError -from gaas_client.types import Node2vecResult +from gaas_client.gaas_thrift import create_server +from gaas_server.gaas_handler import GaasHandler -class GaasHandler: +def create_handler(graph_creation_extension_dir=None): """ - Class which handles RPC requests for a GaasService. + Create and return a GaasHandler instance initialized with options. Setting + graph_creation_extension_dir to a valid dir results in the handler loading + graph creation extensions from that dir. """ - def __init__(self): - self.__next_graph_id = defaults.graph_id + 1 - self.__graph_objs = {} + handler = GaasHandler() + if graph_creation_extension_dir: + handler.load_graph_creation_extensions(graph_creation_extension_dir) + return handler - ############################################################################ - # Environment management - # FIXME: do we need environment mgmt functions? This could be things - # related to querying the state of the dask cluster, - # starting/stopping/restarting, etc. - - ############################################################################ - # Graph management - def create_graph(self): - """ - Create a new graph associated with a new unique graph ID, return the new - graph ID. - """ - pG = PropertyGraph() - return self.__add_graph(pG) - - def delete_graph(self, graph_id): - """ - Remove the graph identified by graph_id from the server. - """ - if self.__graph_objs.pop(graph_id, None) is None: - raise GaasError(f"invalid graph_id {graph_id}") - - def get_graph_ids(self): - return list(self.__graph_objs.keys()) - - def load_csv_as_vertex_data(self, - csv_file_name, - delimiter, - dtypes, - header, - vertex_col_name, - type_name, - property_columns, - graph_id - ): - pG = self.__get_graph(graph_id) - if header == -1: - header = "infer" - elif header == -2: - header = None - # FIXME: error check that file exists - # FIXME: error check that edgelist was read correctly - gdf = cudf.read_csv(csv_file_name, - delimiter=delimiter, - dtype=dtypes, - header=header) - pG.add_vertex_data(gdf, - type_name=type_name, - vertex_col_name=vertex_col_name, - property_columns=property_columns) - - def load_csv_as_edge_data(self, - csv_file_name, - delimiter, - dtypes, - header, - vertex_col_names, - type_name, - property_columns, - graph_id - ): - pG = self.__get_graph(graph_id) - # FIXME: error check that file exists - # FIXME: error check that edgelist read correctly - if header == -1: - header = "infer" - elif header == -2: - header = None - gdf = cudf.read_csv(csv_file_name, - delimiter=delimiter, - dtype=dtypes, - header=header) - pG.add_edge_data(gdf, - type_name=type_name, - vertex_col_names=vertex_col_names, - property_columns=property_columns) - - def get_num_edges(self, graph_id): - pG = self.__get_graph(graph_id) - # FIXME: ensure non-PropertyGraphs that compute num_edges differently - # work too. - return pG.num_edges - - def extract_subgraph(self, - create_using, - selection, - edge_weight_property, - default_edge_weight, - allow_multi_edges, - graph_id - ): - """ - Extract a subgraph, return a new graph ID - """ - pG = self.__get_graph(graph_id) - if not(isinstance(pG, PropertyGraph)): - raise GaasError("extract_subgraph() can only be called on a graph " - "with properties.") - # Convert defaults needed for the Thrift API into defaults used by - # PropertyGraph.extract_subgraph() - create_using = create_using or cugraph.Graph - selection = selection or None - edge_weight_property = edge_weight_property or None - - G = pG.extract_subgraph(create_using, - selection, - edge_weight_property, - default_edge_weight, - allow_multi_edges) - - return self.__add_graph(G) - - ############################################################################ - # Algos - def node2vec(self, start_vertices, max_depth, graph_id): - """ - """ - # FIXME: exception handling - G = self.__get_graph(graph_id) - if isinstance(G, PropertyGraph): - raise GaasError("node2vec() cannot operate directly on a graph with" - " properties, call extract_subgraph() then call " - "node2vec() on the extracted subgraph instead.") - - # FIXME: this should not be needed, need to update cugraph.node2vec to - # also accept a list - start_vertices = cudf.Series(start_vertices, dtype="int32") - - (paths, weights, path_sizes) = \ - cugraph.node2vec(G, start_vertices, max_depth) - - node2vec_result = Node2vecResult( - vertex_paths = paths.to_arrow().to_pylist(), - edge_weights = weights.to_arrow().to_pylist(), - path_sizes = path_sizes.to_arrow().to_pylist() - ) - return node2vec_result - - def pagerank(self, graph_id): - """ - """ - raise NotImplementedError - - ############################################################################ - # Private - def __add_graph(self, G): - gid = self.__next_graph_id - self.__graph_objs[gid] = G - self.__next_graph_id += 1 - return gid - - def __get_graph(self, graph_id): - pG = self.__graph_objs.get(graph_id) - if pG is None: - # Always create the default graph if it does not exist - if graph_id == defaults.graph_id: - pG = PropertyGraph() - self.__graph_objs[graph_id] = pG - else: - raise GaasError(f"invalid graph_id {graph_id}") - return pG - - - -if __name__ == '__main__': - from gaas_client.gaas_thrift import create_server - - # FIXME: add CLI options to set non-default host and port values, and - # possibly other options. - server = create_server(GaasHandler(), - host=defaults.host, - port=defaults.port) - print('Starting the server...') - server.serve() - print('done.') +def start_server_blocking(handler, host, port): + """ + Start the GaaS server on host/port, using handler as the request handler + instance. This call blocks indefinitely until Ctrl-C. + """ + server = create_server(handler, host=host, port=port) + server.serve() # blocks until Ctrl-C (kill -2) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="GaaS - (cu)Graph as a Service") + arg_parser.add_argument("--host", + type=str, + default=defaults.host, + help="hostname the server should use, default " \ + f"is {defaults.host}") + arg_parser.add_argument("--port", + type=int, + default=defaults.port, + help="port the server should listen on, default " \ + f"is {defaults.port}") + arg_parser.add_argument("--graph-creation-extension-dir", + type=Path, + help="dir to load graph creation extension " \ + "functions from") + args = arg_parser.parse_args() + handler = create_handler(args.graph_creation_extension_dir) + print("Starting GaaS...", flush=True) + start_server_blocking(handler, args.host, args.port) + print("done.") diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/tests/conftest.py b/python/tests/conftest.py new file mode 100644 index 0000000..4a6a8ba --- /dev/null +++ b/python/tests/conftest.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest + +graph_creation_extension1_file_contents = """ +import cudf +from cugraph.experimental import PropertyGraph + +def custom_graph_creation_function(): + edgelist = cudf.DataFrame(columns=['src', 'dst'], + data=[(0, 77), (1, 88), (2, 99)]) + pG = PropertyGraph() + pG.add_edge_data(edgelist, vertex_col_names=('src', 'dst')) + return pG +""" + +graph_creation_extension2_file_contents = """ +import cudf +from cugraph.experimental import PropertyGraph + +def __my_private_function(): + pass + +def my_graph_creation_function(arg1, arg2): + edgelist = cudf.DataFrame(columns=[arg1, arg2], data=[(0, 1), (88, 99)]) + pG = PropertyGraph() + pG.add_edge_data(edgelist, vertex_col_names=(arg1, arg2)) + return pG +""" + +graph_creation_extension_long_running_file_contents = """ +import time +import cudf +from cugraph.experimental import PropertyGraph + +def long_running_graph_creation_function(): + time.sleep(10) + pG = PropertyGraph() + return pG +""" + +@pytest.fixture(scope="module") +def graph_creation_extension1(): + with TemporaryDirectory() as tmp_extension_dir: + # write graph creation extension .py file + graph_creation_extension_file = open( + Path(tmp_extension_dir)/"custom_graph_creation_extension.py", + "w") + print(graph_creation_extension1_file_contents, + file=graph_creation_extension_file, + flush=True) + + yield tmp_extension_dir + +@pytest.fixture(scope="module") +def graph_creation_extension2(): + with TemporaryDirectory() as tmp_extension_dir: + # write graph creation extension .py file + graph_creation_extension_file = open( + Path(tmp_extension_dir)/"my_graph_creation_extension.py", + "w") + print(graph_creation_extension2_file_contents, + file=graph_creation_extension_file, + flush=True) + + yield tmp_extension_dir + +@pytest.fixture(scope="module") +def graph_creation_extension_long_running(): + with TemporaryDirectory() as tmp_extension_dir: + # write graph creation extension .py file + graph_creation_extension_file = open( + Path(tmp_extension_dir)/"long_running_graph_creation_extension.py", + "w") + print(graph_creation_extension_long_running_file_contents, + file=graph_creation_extension_file, + flush=True) + + yield tmp_extension_dir diff --git a/tests/demo1.py b/python/tests/demo1.py similarity index 100% rename from tests/demo1.py rename to python/tests/demo1.py diff --git a/tests/gen_demo_data.py b/python/tests/gen_demo_data.py similarity index 100% rename from tests/gen_demo_data.py rename to python/tests/gen_demo_data.py diff --git a/tests/karate.csv b/python/tests/karate.csv similarity index 100% rename from tests/karate.csv rename to python/tests/karate.csv diff --git a/python/tests/test_e2e.py b/python/tests/test_e2e.py new file mode 100644 index 0000000..717f2f7 --- /dev/null +++ b/python/tests/test_e2e.py @@ -0,0 +1,247 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess +from pathlib import Path +import time + +import pytest + +_this_dir = Path(__file__).parent + +_data = {"karate": {"csv_file_name": + (_this_dir/"karate.csv").absolute().as_posix(), + "dtypes": ["int32", "int32", "float32"], + "num_edges": 156, + }, + } + + +############################################################################### +## fixtures + +@pytest.fixture(scope="module") +def server(graph_creation_extension1): + """ + Start a GaaS server, stop it when done with the fixture. This also uses + graph_creation_extension1 to preload a graph creation extension. + """ + from gaas_server import server + from gaas_client import GaasClient + from gaas_client.exceptions import GaasError + + server_file = server.__file__ + server_process = None + host = "localhost" + port = 9090 + graph_creation_extension_dir = graph_creation_extension1 + client = GaasClient(host, port) + + # pytest will update sys.path based on the tests it discovers, and for this + # source tree, an entry for the parent of this "tests" directory will be + # added. The parent to this "tests" directory also allows imports to find + # the GaaS sources, so in oder to ensure the server that's started is also + # using the same sources, the PYTHONPATH env should be set to the sys.path + # being used in this process. + env_dict = os.environ.copy() + env_dict["PYTHONPATH"] = ":".join(sys.path) + + with subprocess.Popen( + [sys.executable, server_file, + "--host", host, + "--port", str(port), + "--graph-creation-extension-dir", graph_creation_extension_dir], + env=env_dict, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) as server_process: + try: + print("\nLaunched GaaS server, waiting for it to start...", + end="", flush=True) + max_retries = 10 + retries = 0 + while retries < max_retries: + try: + client.uptime() + print("started.") + break + except GaasError: + time.sleep(1) + retries += 1 + if retries >= max_retries: + raise RuntimeError("error starting server") + except: + if server_process.poll() is None: + server_process.terminate() + raise + + # yield control to the tests + yield + + # tests are done, now stop the server + print("\nTerminating server...", end="", flush=True) + server_process.terminate() + print("done.", flush=True) + + +@pytest.fixture(scope="function") +def client(server): + from gaas_client import GaasClient, defaults + + client = GaasClient(defaults.host, defaults.port) + + for gid in client.get_graph_ids(): + client.delete_graph(gid) + + #client.unload_graph_creation_extensions() + + yield client + client.close() + + +@pytest.fixture(scope="function") +def client_with_csv_loaded(client): + test_data = _data["karate"] + client.load_csv_as_edge_data(test_data["csv_file_name"], + dtypes=test_data["dtypes"], + vertex_col_names=["0", "1"], + type_name="") + assert client.get_graph_ids() == [0] + return (client, test_data) + + +############################################################################### +## tests + +def test_get_num_edges_default_graph(client_with_csv_loaded): + (client, test_data) = client_with_csv_loaded + assert client.get_num_edges() == test_data["num_edges"] + +def test_load_csv_as_edge_data_nondefault_graph(client): + from gaas_client.exceptions import GaasError + + test_data = _data["karate"] + + with pytest.raises(GaasError): + client.load_csv_as_edge_data(test_data["csv_file_name"], + dtypes=test_data["dtypes"], + vertex_col_names=["0", "1"], + type_name="", + graph_id=9999) + +def test_get_num_edges_nondefault_graph(client_with_csv_loaded): + from gaas_client.exceptions import GaasError + + (client, test_data) = client_with_csv_loaded + with pytest.raises(GaasError): + client.get_num_edges(9999) + + new_graph_id = client.create_graph() + client.load_csv_as_edge_data(test_data["csv_file_name"], + dtypes=test_data["dtypes"], + vertex_col_names=["0", "1"], + type_name="", + graph_id=new_graph_id) + + assert client.get_num_edges() == test_data["num_edges"] + assert client.get_num_edges(new_graph_id) == test_data["num_edges"] + + +def test_node2vec(client_with_csv_loaded): + (client, test_data) = client_with_csv_loaded + extracted_gid = client.extract_subgraph() + start_vertices = 11 + max_depth = 2 + (vertex_paths, edge_weights, path_sizes) = \ + client.node2vec(start_vertices, max_depth, extracted_gid) + # FIXME: consider a more thorough test + assert isinstance(vertex_paths, list) and len(vertex_paths) + assert isinstance(edge_weights, list) and len(edge_weights) + assert isinstance(path_sizes, list) and len(path_sizes) + + +def test_extract_subgraph(client_with_csv_loaded): + (client, test_data) = client_with_csv_loaded + Gid = client.extract_subgraph(create_using=None, + selection=None, + edge_weight_property="2", + default_edge_weight=None, + allow_multi_edges=False) + # FIXME: consider a more thorough test + assert Gid in client.get_graph_ids() + + +def test_load_and_call_graph_creation_extension(client, + graph_creation_extension2): + """ + Tests calling a user-defined server-side graph creation extension from the + GaaS client. + """ + # The graph_creation_extension returns the tmp dir created which contains + # the extension + extension_dir = graph_creation_extension2 + + num_files_loaded = client.load_graph_creation_extensions(extension_dir) + assert num_files_loaded == 1 + + new_graph_ID = client.call_graph_creation_extension( + "my_graph_creation_function", "a", "b") + + assert new_graph_ID in client.get_graph_ids() + + # Inspect the PG and ensure it was created from my_graph_creation_function + # FIXME: add client APIs to allow for a more thorough test of the graph + assert client.get_num_edges(new_graph_ID) == 2 + + +def test_load_and_call_graph_creation_long_running_extension( + client, + graph_creation_extension_long_running): + """ + Tests calling a user-defined server-side graph creation extension from the + GaaS client. + """ + # The graph_creation_extension returns the tmp dir created which contains + # the extension + extension_dir = graph_creation_extension_long_running + + num_files_loaded = client.load_graph_creation_extensions(extension_dir) + assert num_files_loaded == 1 + + new_graph_ID = client.call_graph_creation_extension( + "long_running_graph_creation_function") + + assert new_graph_ID in client.get_graph_ids() + + # Inspect the PG and ensure it was created from my_graph_creation_function + # FIXME: add client APIs to allow for a more thorough test of the graph + assert client.get_num_edges(new_graph_ID) == 0 + + +def test_call_graph_creation_extension(client): + """ + Ensure the graph creation extension preloaded by the server fixture is + callable. + """ + new_graph_ID = client.call_graph_creation_extension( + "custom_graph_creation_function") + + assert new_graph_ID in client.get_graph_ids() + + # Inspect the PG and ensure it was created from + # custom_graph_creation_function + # FIXME: add client APIs to allow for a more thorough test of the graph + assert client.get_num_edges(new_graph_ID) == 3 diff --git a/python/tests/test_gaas_handler.py b/python/tests/test_gaas_handler.py new file mode 100644 index 0000000..a5b7cbd --- /dev/null +++ b/python/tests/test_gaas_handler.py @@ -0,0 +1,117 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +############################################################################### +## fixtures +# The fixtures used in these tets are defined in conftest.py + + +############################################################################### +## tests + +def test_load_and_call_graph_creation_extension(graph_creation_extension2): + """ + Ensures load_extensions reads the extensions and makes the new APIs they add + available. + """ + from gaas_server.gaas_handler import GaasHandler + from gaas_client.exceptions import GaasError + + handler = GaasHandler() + + extension_dir = graph_creation_extension2 + + # DNE + with pytest.raises(GaasError): + handler.load_graph_creation_extensions("/path/that/does/not/exist") + + # Exists, but is a file + with pytest.raises(GaasError): + handler.load_graph_creation_extensions(__file__) + + # Load the extension and call the function defined in it + num_files_read = handler.load_graph_creation_extensions(extension_dir) + assert num_files_read == 1 + + # Private function should not be callable + with pytest.raises(GaasError): + handler.call_graph_creation_extension("__my_private_function", + "()", "{}") + + # Function which DNE in the extension + with pytest.raises(GaasError): + handler.call_graph_creation_extension("bad_function_name", + "()", "{}") + + # Wrong number of args + with pytest.raises(GaasError): + handler.call_graph_creation_extension("my_graph_creation_function", + "('a',)", "{}") + + # This call should succeed and should result in a new PropertyGraph present + # in the handler instance. + new_graph_ID = handler.call_graph_creation_extension( + "my_graph_creation_function", "('a', 'b')", "{}") + + assert new_graph_ID in handler.get_graph_ids() + + # Inspect the PG and ensure it was created from my_graph_creation_function + pG = handler._get_graph(new_graph_ID) + edge_props = pG.edge_property_names + assert ("a" in edge_props) and ("b" in edge_props) + + +def test_load_and_unload_graph_creation_extension(graph_creation_extension2): + """ + Ensure extensions can be unloaded. + """ + from gaas_server.gaas_handler import GaasHandler + from gaas_client.exceptions import GaasError + + handler = GaasHandler() + + extension_dir = graph_creation_extension2 + + # Load the extensions and ensure it can be called. + handler.load_graph_creation_extensions(extension_dir) + new_graph_ID = handler.call_graph_creation_extension( + "my_graph_creation_function", "('a', 'b')", "{}") + assert new_graph_ID in handler.get_graph_ids() + + # Unload then try to run the same call again, which should fail + handler.unload_graph_creation_extensions() + + with pytest.raises(GaasError): + handler.call_graph_creation_extension( + "my_graph_creation_function", "('a', 'b')", "{}") + + +def test_load_and_unload_graph_creation_extension_no_args( + graph_creation_extension1): + """ + Test graph_creation_extension1 which contains an extension with no args. + """ + from gaas_server.gaas_handler import GaasHandler + handler = GaasHandler() + + extension_dir = graph_creation_extension1 + + # Load the extensions and ensure it can be called. + handler.load_graph_creation_extensions(extension_dir) + new_graph_ID = handler.call_graph_creation_extension( + "custom_graph_creation_function", "()", "{}") + assert new_graph_ID in handler.get_graph_ids() diff --git a/tests/test_client.py b/tests/test_client.py deleted file mode 100644 index 2210367..0000000 --- a/tests/test_client.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path - -import pytest - -_this_dir = Path(__file__).parent - -_data = {"karate": {"csv_file_name": - (_this_dir/"karate.csv").absolute().as_posix(), - "dtypes": ["int32", "int32", "float32"], - "num_edges": 156, - }, - } - - -############################################################################### -## fixtures - -@pytest.fixture -def client(): - from gaas_client import GaasClient, defaults - - client = GaasClient(defaults.host, defaults.port) - # FIXME: this ensures a server that was running from a previous test is - # empty. Consider a different way to test using a new server instance. - for gid in client.get_graph_ids(): - client.delete_graph(gid) - - yield client - client.close() - - -@pytest.fixture -def client_with_csv_loaded(client): - test_data = _data["karate"] - client.load_csv_as_edge_data(test_data["csv_file_name"], - dtypes=test_data["dtypes"], - vertex_col_names=["0", "1"], - type_name="") - assert client.get_graph_ids() == [0] - return (client, test_data) - - -############################################################################### -## tests - -def test_get_num_edges_default_graph(client_with_csv_loaded): - (client, test_data) = client_with_csv_loaded - assert client.get_num_edges() == test_data["num_edges"] - -def test_load_csv_as_edge_data_nondefault_graph(client): - from gaas_client.exceptions import GaasError - - test_data = _data["karate"] - - with pytest.raises(GaasError): - client.load_csv_as_edge_data(test_data["csv_file_name"], - dtypes=test_data["dtypes"], - vertex_col_names=["0", "1"], - type_name="", - graph_id=9999) - -def test_get_num_edges_nondefault_graph(client_with_csv_loaded): - from gaas_client.exceptions import GaasError - - (client, test_data) = client_with_csv_loaded - with pytest.raises(GaasError): - client.get_num_edges(9999) - - new_graph_id = client.create_graph() - client.load_csv_as_edge_data(test_data["csv_file_name"], - dtypes=test_data["dtypes"], - vertex_col_names=["0", "1"], - type_name="", - graph_id=new_graph_id) - - assert client.get_num_edges() == test_data["num_edges"] - assert client.get_num_edges(new_graph_id) == test_data["num_edges"] - - -def test_node2vec(client_with_csv_loaded): - (client, test_data) = client_with_csv_loaded - extracted_gid = client.extract_subgraph() - start_vertices = 11 - max_depth = 2 - (vertex_paths, edge_weights, path_sizes) = \ - client.node2vec(start_vertices, max_depth, extracted_gid) - # FIXME: consider a more thorough test - assert isinstance(vertex_paths, list) and len(vertex_paths) - assert isinstance(edge_weights, list) and len(edge_weights) - assert isinstance(path_sizes, list) and len(path_sizes) - - -def test_extract_subgraph(client_with_csv_loaded): - (client, test_data) = client_with_csv_loaded - Gid = client.extract_subgraph(create_using=None, - selection=None, - edge_weight_property="2", - default_edge_weight=None, - allow_multi_edges=False) - # FIXME: consider a more thorough test - assert Gid in client.get_graph_ids()