diff --git a/python/cugraph/cugraph/dask/comms/comms.py b/python/cugraph/cugraph/dask/comms/comms.py index 5499b13af03..1e1c28fbbee 100644 --- a/python/cugraph/cugraph/dask/comms/comms.py +++ b/python/cugraph/cugraph/dask/comms/comms.py @@ -146,8 +146,6 @@ def initialize(comms=None, p2p=False, prows=None, pcols=None, partition_type=1): __default_handle = None if comms is None: # Initialize communicator - if not p2p: - raise Exception("Set p2p to True for running mnmg algorithms") __instance = raftComms(comms_p2p=p2p) __instance.init() # Initialize subcommunicator diff --git a/python/cugraph/cugraph/testing/mg_utils.py b/python/cugraph/cugraph/testing/mg_utils.py index 32854652f05..07399b90627 100644 --- a/python/cugraph/cugraph/testing/mg_utils.py +++ b/python/cugraph/cugraph/testing/mg_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,6 +35,7 @@ def start_dask_client( jit_unspill=False, worker_class=None, device_memory_limit=0.8, + p2p=True, ): """ Creates a new dask client, and possibly also a cluster, and returns them as @@ -95,6 +96,9 @@ def start_dask_client( dask_cuda.LocalCUDACluster for details. This parameter is ignored if the env var SCHEDULER_FILE is set which implies the dask cluster has already been created. + + p2p : bool, optional (default=True) + Initialize UCX endpoints if True. """ dask_scheduler_file = os.environ.get("SCHEDULER_FILE") dask_local_directory = os.getenv("DASK_LOCAL_DIRECTORY") @@ -164,7 +168,7 @@ def start_dask_client( # FIXME: use proper logging, INFO or DEBUG level print("\nDask client/cluster created using LocalCUDACluster") - Comms.initialize(p2p=True) + Comms.initialize(p2p=p2p) return (client, cluster) diff --git a/python/cugraph/cugraph/tests/conftest.py b/python/cugraph/cugraph/tests/conftest.py index cb5755128eb..d31c2968afe 100644 --- a/python/cugraph/cugraph/tests/conftest.py +++ b/python/cugraph/cugraph/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -52,6 +52,21 @@ def dask_client(): stop_dask_client(dask_client, dask_cluster) +# FIXME: Add tests leveraging this fixture +@pytest.fixture(scope="module") +def dask_client_non_p2p(): + # start_dask_client will check for the SCHEDULER_FILE and + # DASK_WORKER_DEVICES env vars and use them when creating a client if + # set. start_dask_client will also initialize the Comms singleton. + dask_client, dask_cluster = start_dask_client( + worker_class=IncreasedCloseTimeoutNanny, p2p=False + ) + + yield dask_client + + stop_dask_client(dask_client, dask_cluster) + + @pytest.fixture(scope="module") def scratch_dir(): # This should always be set if doing MG testing, since temporary