From 2518bd248c050c5d472f80940e43eab3a85ad29a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 26 Jun 2024 19:57:15 +0800 Subject: [PATCH 1/4] Basic distributed test for external memory. --- python-package/xgboost/testing/__init__.py | 5 +- tests/python/test_collective.py | 80 +++++++++++++++++++++- 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 482da68c9fc9..af39e5602c94 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -255,6 +255,7 @@ def make_batches( use_cupy: bool = False, *, vary_size: bool = False, + random_state: int = 1994, ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: X = [] y = [] @@ -262,9 +263,9 @@ def make_batches( if use_cupy: import cupy - rng = cupy.random.RandomState(1994) + rng = cupy.random.RandomState(random_state) else: - rng = np.random.RandomState(1994) + rng = np.random.RandomState(random_state) for i in range(n_batches): n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch _X = rng.randn(n_samples, n_features) diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 2beedf8a1caf..58e952818d3a 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -5,10 +5,10 @@ import numpy as np import pytest - import xgboost as xgb from xgboost import RabitTracker, build_info, federated from xgboost import testing as tm +from xgboost.compat import concat def run_rabit_worker(rabit_env, world_size): @@ -79,3 +79,81 @@ def test_federated_communicator(): for worker in workers: worker.join() assert worker.exitcode == 0 + + +def run_external_memory(worker_id: int, world_size: int, kwargs: dict) -> None: + n_samples_per_batch = 32 + n_features = 4 + n_batches = 16 + use_cupy = False + + n_cpus = multiprocessing.cpu_count() + with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **kwargs): + it = tm.IteratorForTest( + *tm.make_batches( + n_samples_per_batch, + n_features, + n_batches, + use_cupy, + random_state=worker_id, + ), + cache="cache", + ) + Xy = xgb.DMatrix(it, nthread=n_cpus // world_size) + results = {} + booster = xgb.train( + {"tree_method": "hist", "nthread": n_cpus // world_size}, + Xy, + evals=[(Xy, "Train")], + num_boost_round=32, + evals_result=results, + ) + assert tm.non_increasing(results["Train"]["rmse"]) + + lx, ly, lw = [], [], [] + for i in range(world_size): + x, y, w = tm.make_batches( + n_samples_per_batch, + n_features, + n_batches, + use_cupy, + random_state=i, + ) + lx.extend(x) + ly.extend(y) + lw.extend(w) + + X = concat(lx) + yconcat = concat(ly) + wconcat = concat(lw) + Xy = xgb.DMatrix(X, yconcat, wconcat, nthread=n_cpus // world_size) + + results_local = {} + booster = xgb.train( + {"tree_method": "hist", "nthread": n_cpus // world_size}, + Xy, + evals=[(Xy, "Train")], + num_boost_round=32, + evals_result=results_local, + ) + np.testing.assert_allclose(results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4) + + +def test_external_memory() -> None: + world_size = 3 + + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size) + tracker.start() + args = tracker.worker_args() + workers = [] + + for rank in range(world_size): + worker = multiprocessing.Process( + target=run_external_memory, args=(rank, world_size, args) + ) + worker.start() + workers.append(worker) + + for worker in workers: + worker.join() + assert worker.exitcode == 0 From d6f5eee780b904bb3eedb2d2fd10d7d788cb88e1 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 5 Jul 2024 17:57:50 +0800 Subject: [PATCH 2/4] threads. --- tests/python/test_collective.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 58e952818d3a..8021cea8412c 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -99,7 +99,7 @@ def run_external_memory(worker_id: int, world_size: int, kwargs: dict) -> None: ), cache="cache", ) - Xy = xgb.DMatrix(it, nthread=n_cpus // world_size) + Xy = xgb.DMatrix(it, nthread=max(n_cpus // world_size, 1)) results = {} booster = xgb.train( {"tree_method": "hist", "nthread": n_cpus // world_size}, @@ -126,7 +126,7 @@ def run_external_memory(worker_id: int, world_size: int, kwargs: dict) -> None: X = concat(lx) yconcat = concat(ly) wconcat = concat(lw) - Xy = xgb.DMatrix(X, yconcat, wconcat, nthread=n_cpus // world_size) + Xy = xgb.DMatrix(X, yconcat, wconcat, nthread=max(n_cpus // world_size, 1)) results_local = {} booster = xgb.train( From 88739b3d5120e49f3555d680fe1cabeb0e408b7b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 5 Jul 2024 18:01:26 +0800 Subject: [PATCH 3/4] Change into dask test. --- tests/ci_build/lint_python.py | 1 + tests/python/test_collective.py | 80 +---------------- .../test_with_dask/test_external_memory.py | 88 +++++++++++++++++++ 3 files changed, 90 insertions(+), 79 deletions(-) create mode 100644 tests/test_distributed/test_with_dask/test_external_memory.py diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index 079996de66fb..f8bbbc2848b0 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -98,6 +98,7 @@ class LintersPaths: "tests/python/test_model_io.py", "tests/test_distributed/test_federated/", "tests/test_distributed/test_gpu_federated/", + "tests/test_distributed/test_with_dask/test_external_memory.py", "tests/test_distributed/test_with_spark/test_data.py", "tests/test_distributed/test_gpu_with_spark/test_data.py", "tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py", diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 8021cea8412c..2beedf8a1caf 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -5,10 +5,10 @@ import numpy as np import pytest + import xgboost as xgb from xgboost import RabitTracker, build_info, federated from xgboost import testing as tm -from xgboost.compat import concat def run_rabit_worker(rabit_env, world_size): @@ -79,81 +79,3 @@ def test_federated_communicator(): for worker in workers: worker.join() assert worker.exitcode == 0 - - -def run_external_memory(worker_id: int, world_size: int, kwargs: dict) -> None: - n_samples_per_batch = 32 - n_features = 4 - n_batches = 16 - use_cupy = False - - n_cpus = multiprocessing.cpu_count() - with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **kwargs): - it = tm.IteratorForTest( - *tm.make_batches( - n_samples_per_batch, - n_features, - n_batches, - use_cupy, - random_state=worker_id, - ), - cache="cache", - ) - Xy = xgb.DMatrix(it, nthread=max(n_cpus // world_size, 1)) - results = {} - booster = xgb.train( - {"tree_method": "hist", "nthread": n_cpus // world_size}, - Xy, - evals=[(Xy, "Train")], - num_boost_round=32, - evals_result=results, - ) - assert tm.non_increasing(results["Train"]["rmse"]) - - lx, ly, lw = [], [], [] - for i in range(world_size): - x, y, w = tm.make_batches( - n_samples_per_batch, - n_features, - n_batches, - use_cupy, - random_state=i, - ) - lx.extend(x) - ly.extend(y) - lw.extend(w) - - X = concat(lx) - yconcat = concat(ly) - wconcat = concat(lw) - Xy = xgb.DMatrix(X, yconcat, wconcat, nthread=max(n_cpus // world_size, 1)) - - results_local = {} - booster = xgb.train( - {"tree_method": "hist", "nthread": n_cpus // world_size}, - Xy, - evals=[(Xy, "Train")], - num_boost_round=32, - evals_result=results_local, - ) - np.testing.assert_allclose(results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4) - - -def test_external_memory() -> None: - world_size = 3 - - tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size) - tracker.start() - args = tracker.worker_args() - workers = [] - - for rank in range(world_size): - worker = multiprocessing.Process( - target=run_external_memory, args=(rank, world_size, args) - ) - worker.start() - workers.append(worker) - - for worker in workers: - worker.join() - assert worker.exitcode == 0 diff --git a/tests/test_distributed/test_with_dask/test_external_memory.py b/tests/test_distributed/test_with_dask/test_external_memory.py new file mode 100644 index 000000000000..cf475d90f294 --- /dev/null +++ b/tests/test_distributed/test_with_dask/test_external_memory.py @@ -0,0 +1,88 @@ +from typing import List, cast + +import numpy as np +from distributed import Client, Scheduler, Worker, get_worker +from distributed.utils_test import gen_cluster + +import xgboost as xgb +from xgboost import testing as tm +from xgboost.compat import concat + + +def run_external_memory(worker_id: int, n_workers: int, comm_args: dict) -> None: + n_samples_per_batch = 32 + n_features = 4 + n_batches = 16 + use_cupy = False + + n_threads = get_worker().state.nthreads + with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args): + it = tm.IteratorForTest( + *tm.make_batches( + n_samples_per_batch, + n_features, + n_batches, + use_cupy, + random_state=worker_id, + ), + cache="cache", + ) + Xy = xgb.DMatrix(it, nthread=n_threads) + results: xgb.callback.TrainingCallback.EvalsLog = {} + booster = xgb.train( + {"tree_method": "hist", "nthread": n_threads}, + Xy, + evals=[(Xy, "Train")], + num_boost_round=32, + evals_result=results, + ) + assert tm.non_increasing(cast(List[float], results["Train"]["rmse"])) + + lx, ly, lw = [], [], [] + for i in range(n_workers): + x, y, w = tm.make_batches( + n_samples_per_batch, + n_features, + n_batches, + use_cupy, + random_state=i, + ) + lx.extend(x) + ly.extend(y) + lw.extend(w) + + X = concat(lx) + yconcat = concat(ly) + wconcat = concat(lw) + Xy = xgb.DMatrix(X, yconcat, wconcat, nthread=n_threads) + + results_local: xgb.callback.TrainingCallback.EvalsLog = {} + booster = xgb.train( + {"tree_method": "hist", "nthread": n_threads}, + Xy, + evals=[(Xy, "Train")], + num_boost_round=32, + evals_result=results_local, + ) + np.testing.assert_allclose( + results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4 + ) + + +@gen_cluster(client=True) +async def test_external_memory( + client: Client, s: Scheduler, a: Worker, b: Worker +) -> None: + workers = tm.get_client_workers(client) + args = await client.sync( + xgb.dask._get_rabit_args, + len(workers), + None, + client, + ) + n_workers = len(workers) + + futs = client.map( + run_external_memory, range(n_workers), n_workers=n_workers, comm_args=args + ) + await client.gather(futs) From d6ae91b4e3c00c559e1fef181c116b3b0c2d4113 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 5 Jul 2024 22:55:13 +0800 Subject: [PATCH 4/4] pylint. --- python-package/xgboost/testing/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index af39e5602c94..e0096c89c9a8 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -248,7 +248,7 @@ def as_arrays( return X, y, w -def make_batches( +def make_batches( # pylint: disable=too-many-arguments,too-many-locals n_samples_per_batch: int, n_features: int, n_batches: int,