
Commit

Merge branch 'branch-23.12' into html-repr
dantegd authored Oct 31, 2023
2 parents 20db468 + a01181f commit 3828291
Showing 5 changed files with 81 additions and 31 deletions.
55 changes: 28 additions & 27 deletions docs/source/execution_device_interoperability.ipynb
@@ -6,14 +6,16 @@
"source": [
"# cuML on GPU and CPU\n",
"\n",
"cuML is a Scikit-learn-based suite of fast, GPU-accelerated machine learning algorithms designed for data science and analytical tasks. Starting with version 23.10, a new version of cuML can also be run on CPU systems, increasing its ease of use (without code changes) in the following manners: \n",
"cuML is a Scikit-learn-like suite of fast, GPU-accelerated machine learning algorithms designed for data science and analytical tasks.\n",
"\n",
"- Allow users to prototype in systems without GPUs. \n",
"- Allow library integrations without the need of dispatching and boilerplate code. \n",
"- Allow users to train on one type of system and infer with the other in a subset of estimators (that will grow with each version). \n",
"- Provide compatibility with the GPU/CPU open source pydata ecosystem.\n",
"Starting with version 23.10, cuML provides both GPU-based and CPU-based execution capabilities with zero code change required to switch between them. This unified CPU/GPU cuML: \n",
"\n",
"The majority of estimators of cuML can run in both CPU and GPU systems, with a subset of them allowing exporting models between GPU and CPU systems. The following table shows support for the most common estimators: \n",
"- Allows users to prototype in systems without GPUs. \n",
"- Allows library integrations without the need for dispatching and boilerplate code. \n",
"- Allows users to train on one type of system and infer with the other for a subset of estimators (that will expand over time). \n",
"- Provides compatibility with the broader GPU/CPU open source pydata ecosystem.\n",
"\n",
"The majority of estimators of cuML can run in both CPU and GPU systems, with a subset of them supporting exporting models between GPU and CPU systems. The following table shows support for the most common estimators: \n",
"\n",
"| Category | Algorithm | Supports Execution on CPU | Supports Exporting between CPU and GPU | \n",
"| --- | --- | --- | --- |\n",
@@ -45,7 +47,9 @@
"| **Time Series** | Holt-Winters Exponential Smoothing | No | No |\n",
"| | Auto-regressive Integrated Moving Average (ARIMA) | No | No |\n",
"\n",
"This allows the same code to be guaranteed to run in both GPU and CPU systems. Version 23.12 is scheduled to add the following algorithms: Random Forest and Support Vector Machine estimators. \n",
"This allows the same code to be guaranteed to run in both GPU and CPU systems. Version 23.12 is scheduled to add the following algorithms:\n",
"- Random Forest\n",
"- Support Vector Machine estimators\n",
"\n"
]
},
@@ -57,23 +61,23 @@
"\n",
"## Installation\n",
"\n",
"For GPU systems, cuML still follows the [RAPIDS requirements] and nothing has changed for installing it. The cuML package and wheels are universal and can run in both GPU and CPU modes. For installing in CPU systems, similar to other packages it can be installed from conda/mamba with:\n",
"For GPU systems, cuML still follows the [RAPIDS requirements](https://rapids.ai/#quick-start). The cuML package and wheels are universal and can run in both GPU and CPU modes. To use cuML in CPU-only systems, you can install using conda/mamba with:\n",
"\n",
"```bash\n",
"mamba install -c rapidsai -c nvidia -c conda-forge cuml-cpu=23.10 \n",
"# mamba install -c rapidsai-nightly -c nvidia -c conda-forge cuml-cpu=23.12 # for nightly builds\n",
"```\n",
"\n",
"- cuML 23.10 supports Linux and WSL2 on GPU and CPU systems using conda. \n",
"- cuML 23.12 will bring support for pip wheels and macos support for CPU execution. \n",
"- cuML 23.12 will bring support for pip wheels and MacOS support for CPU execution. \n",
"\n",
"### How to Use\n",
"\n",
"There are two main ways to use the CPU capabilities of cuML:\n",
"\n",
"#### 1. Using CPU Package directly\n",
"\n",
"The CPU package, `cuml-cpu` is a subset of the `cuml` package, so besides the difference in installation there is no changes needed to the code of supported estimators to run code. For example, the following script can be run both in a system with GPU and `cuml`, as well as a system without GPU and `cuml-cpu`:"
"The CPU package, `cuml-cpu` is a subset of the `cuml` package, so there are zero code changes required to run the code when using a CPU-only system. For example, the following script can be run both in a system with GPU and `cuml`, as well as a system without GPU and `cuml-cpu`:"
]
},
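The script cell itself is collapsed in this view; the following is a minimal sketch of such a script, with the estimator (UMAP) and random data chosen purely for illustration:

```python
# Runs unchanged with either the GPU package (`cuml`)
# or the CPU-only package (`cuml-cpu`).
import numpy as np
from cuml.manifold import UMAP

# Illustrative random data (not the notebook's actual cell).
X = np.random.RandomState(42).rand(1000, 16).astype(np.float32)

embedding = UMAP(n_neighbors=15, n_components=2).fit_transform(X)
print(embedding.shape)  # (1000, 2)
```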
{
@@ -110,7 +114,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This allows to prototype on CPU systems and then run code on GPU servers, or the other way around. Some estimators support training on one type of system and then exporting models to the other type, as can be seen in [the corresponding section](#Cross-Device-Training-and-Inference-Serialization)."
"This allows easy prototyping on CPU systems and running production code on GPU servers, or the other way around. Some estimators support training on one type of system and then exporting models to the other type, as noted above and explained by example in [the corresponding section](#Cross-Device-Training-and-Inference-Serialization)."
]
},
{
@@ -119,7 +123,7 @@
"source": [
"#### 2. Managing Execution Platform with GPU package\n",
"\n",
"Additionally to allowing the same code to be run in CPU systems, users can control which device executes parts of the code. So in addition to the first example that can just be run in a CPU system with `cuml-cpu`, a system with the full cuML can execute in CPU mode as well. \n",
"In addition to allowing the zero-code change execution in CPU systems, users can also manually control which device executes parts of the code when using a system with the full cuML.\n",
"\n",
"For example, using the following data: "
]
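The data cell is collapsed in this view; as an illustrative stand-in, assuming synthetic blob data and hypothetical names `X_train`/`y_train` that the sketches below reuse:

```python
from cuml.datasets import make_blobs

# Hypothetical stand-in for the collapsed data cell.
X_train, y_train = make_blobs(
    n_samples=2000, n_features=20, centers=5, random_state=42
)
```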
@@ -155,7 +159,7 @@
"source": [
"There are two ways to control the execution of the code:\n",
"\n",
"#### a) `using_device_type` context manager:"
"#### a) `using_device_type` context manager"
]
},
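The example cell is collapsed in this view; a minimal sketch of the context manager, assuming the `X_train` data from above and an estimator that supports CPU execution:

```python
from cuml.common.device_selection import using_device_type
from cuml.neighbors import NearestNeighbors

nn = NearestNeighbors(n_neighbors=5)
with using_device_type('cpu'):   # fit on the host
    nn.fit(X_train)
with using_device_type('gpu'):   # query on the device
    distances, indices = nn.kneighbors(X_train)
```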
{
@@ -177,9 +181,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This allows to prototype but also to run different estimators on different devices, for example in the case where data is small so that moving the data around wouldn't allow the GPU to accelerate an estimator. \n",
"This makes it easy to prototype and run different estimators on different devices, for example in the case where data is small so that moving the data around wouldn't allow the GPU to accelerate an estimator. \n",
"\n",
"Additionally, it allows to run estimators using unsupported parameter: "
"It also allows running estimators using unsupported parameters: "
]
},
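The cell is collapsed in this view; a sketch of the idea, assuming a hyperparameter such as UMAP's `angular_rp_forest` that only the CPU (umap-learn) code path honors:

```python
from cuml.common.device_selection import using_device_type
from cuml.manifold import UMAP

# Assumption: `angular_rp_forest` is only supported by the CPU
# implementation, so force CPU execution for this estimator.
model = UMAP(angular_rp_forest=True)
with using_device_type('cpu'):
    model.fit(X_train)
```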
{
@@ -201,14 +205,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"An upcoming feature will allow for this to also dispatch automatically. This can be very useful for library integrators, so that if users use parameters not supported on GPUs, the code automatically will dispatch to a CPU implementation. "
"An upcoming feature will allow for this dispatch to occur automatically under-the-hood. This can be very useful for when integrating cuML into other libraries, so that if users use parameters not supported on GPUs, the code automatically will dispatch to a CPU implementation. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### b) Global configuration. "
"#### b) Global configuration with `set_global_device_type`"
]
},
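The configuration cell is collapsed in this view; a minimal sketch, assuming the `set_global_device_type`/`get_global_device_type` helpers live alongside `using_device_type` in `cuml.common.device_selection`:

```python
from cuml.common.device_selection import (
    get_global_device_type,
    set_global_device_type,
)

set_global_device_type('cpu')   # all subsequent estimators execute on CPU
print(get_global_device_type())
set_global_device_type('gpu')   # restore the default GPU execution
```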
{
@@ -248,7 +252,7 @@
"source": [
"## Cross Device Training and Inference Serialization\n",
"\n",
"As stated before, a subset of the estimators that can be executed on the CPU, also allow to serialize estimators trained on one type of device (CPU or GPU) and then deserialize it on the other one. \n",
"As stated above, a subset of the estimators support training on one type of device (CPU or GPU), serializing the trained model, and then deserializing and executing it on the other type of device. \n",
"\n",
"To do this, a simple API is provided. For example, To train a model on GPU but deploy it on CPU, first, train the estimator on device and save it to disk:"
]
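The training and serialization cells are collapsed in this view; a sketch of the full round trip, with the estimator choice (`LinearRegression`), the data names, and the file name all illustrative:

```python
import pickle

from cuml.linear_model import LinearRegression

# On the GPU system: train, then serialize to disk.
model = LinearRegression()
model.fit(X_train, y_train)
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

# Later, on a CPU-only system with `cuml-cpu` installed:
with open('model.pkl', 'rb') as f:
    deployed_model = pickle.load(f)
predictions = deployed_model.predict(X_train)
```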
@@ -291,20 +295,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusions\n",
"## Conclusion\n",
"\n",
"cuML's CPU capabilities are designed to facilitate different usecases, lower the requirements to use the capabilities of cuML, as well as increasing the flexibility and capabilities of integration and deployment of the library. \n",
"cuML's CPU capabilities are designed to facilitate different use cases, lower the barriers to using the capabilities of cuML, an streamline integrating cuML into other tools and deploying models. \n",
"\n",
"Upcoming versions of cuML will increase the supported estimators, both for CPU execution as well as serializing/exporting models between systems with and without GPUs. "
"Upcoming versions of cuML will expand the supported estimators, both for CPU execution as well as serializing/exporting models between systems with and without GPUs. "
]
}
],
"metadata": {
"interpreter": {
"hash": "35840739db47a5016f18b089945bf3e154a2dca6d71cfb13687d370b69a146e3"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.10.12 ('cuml_dev')",
"language": "python",
"name": "python3"
},
@@ -318,11 +319,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
"hash": "975233ed6ddd7eb5f50db124c7eb6e9abd7f2428099fbb1c703209662350014b"
}
}
},
8 changes: 7 additions & 1 deletion python/cuml/benchmark/automated/dask/conftest.py
@@ -18,6 +18,7 @@

from dask_cuda import initialize
from dask_cuda import LocalCUDACluster
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
from dask.distributed import Client

enable_tcp_over_ucx = True
@@ -28,7 +29,11 @@
@pytest.fixture(scope="module")
def cluster():

cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0)
cluster = LocalCUDACluster(
protocol="tcp",
scheduler_port=0,
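# IncreasedCloseTimeoutNanny raises the worker close timeout to avoid flaky teardown failures.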
worker_class=IncreasedCloseTimeoutNanny,
)
yield cluster
cluster.close()

@@ -54,6 +59,7 @@ def ucx_cluster():
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_nvlink=enable_nvlink,
enable_infiniband=enable_infiniband,
worker_class=IncreasedCloseTimeoutNanny,
)
yield cluster
cluster.close()
4 changes: 2 additions & 2 deletions python/cuml/svm/linear.pyx
@@ -300,8 +300,8 @@ cdef class LinearSVMWrapper:
if self.dtype != np.float32 and self.dtype != np.float64:
raise TypeError('Input data type must be float32 or float64')

cdef uintptr_t Xptr = <uintptr_t>X.ptr
cdef uintptr_t yptr = <uintptr_t>y.ptr
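# Fall back to a null pointer when the corresponding array is not provided.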
cdef uintptr_t Xptr = <uintptr_t>X.ptr if X is not None else 0
cdef uintptr_t yptr = <uintptr_t>y.ptr if y is not None else 0
cdef uintptr_t swptr = <uintptr_t>sampleWeight.ptr \
if sampleWeight is not None else 0
cdef size_t nCols = 0
8 changes: 7 additions & 1 deletion python/cuml/tests/dask/conftest.py
@@ -4,6 +4,7 @@

from dask_cuda import initialize
from dask_cuda import LocalCUDACluster
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
from dask.distributed import Client

enable_tcp_over_ucx = True
@@ -14,7 +15,11 @@
@pytest.fixture(scope="module")
def cluster():

cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0)
cluster = LocalCUDACluster(
protocol="tcp",
scheduler_port=0,
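# IncreasedCloseTimeoutNanny raises the worker close timeout to avoid flaky teardown failures.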
worker_class=IncreasedCloseTimeoutNanny,
)
yield cluster
cluster.close()

@@ -40,6 +45,7 @@ def ucx_cluster():
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_nvlink=enable_nvlink,
enable_infiniband=enable_infiniband,
worker_class=IncreasedCloseTimeoutNanny,
)
yield cluster
cluster.close()
37 changes: 37 additions & 0 deletions python/cuml/tests/test_pickle.py
@@ -740,6 +740,43 @@ def assert_model(pickled_model, data):
pickle_save_load(tmpdir, create_mod, assert_model)


@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize(
"params", [{"probability": True}, {"probability": False}]
)
@pytest.mark.parametrize("multiclass", [True, False])
def test_linear_svc_pickle(tmpdir, datatype, params, multiclass):
result = {}

def create_mod():
model = cuml.svm.LinearSVC(**params)
iris = load_iris()
iris_selection = np.random.RandomState(42).choice(
[True, False], 150, replace=True, p=[0.75, 0.25]
)
X_train = iris.data[iris_selection]
y_train = iris.target[iris_selection]
if not multiclass:
y_train = (y_train > 0).astype(datatype)
data = [X_train, y_train]
result["model"] = model.fit(X_train, y_train)
return model, data

def assert_model(pickled_model, data):
if result["model"].probability:
print("Comparing probabilistic LinearSVC")
compare_probabilistic_svm(
result["model"], pickled_model, data[0], data[1], 0, 0
)
else:
print("comparing base LinearSVC")
pred_before = result["model"].predict(data[0])
pred_after = pickled_model.predict(data[0])
assert array_equal(pred_before, pred_after)

pickle_save_load(tmpdir, create_mod, assert_model)


@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize("nrows", [unit_param(500)])
@pytest.mark.parametrize("ncols", [unit_param(16)])
