Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/branch-24.02' into spdlog_1.12…
Browse files Browse the repository at this point in the history
…_fmt_10
  • Loading branch information
bdice committed Nov 30, 2023
2 parents b2cf4df + 63f98eb commit 33d30aa
Show file tree
Hide file tree
Showing 17 changed files with 839 additions and 194 deletions.
12 changes: 10 additions & 2 deletions cpp/src/svm/kernelcache.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,16 @@ class BatchCache : public raft::cache::Cache<math_t> {
RAFT_CUDA_TRY(cudaMemsetAsync(tmp_buffer, 0, n_ws * 2 * sizeof(int), stream));

// Init cub buffers
cub::DeviceRadixSort::SortKeys(
NULL, d_temp_storage_size, tmp_buffer, tmp_buffer, n_ws, 0, sizeof(int) * 8, stream);
cub::DeviceRadixSort::SortPairs(NULL,
d_temp_storage_size,
tmp_buffer,
tmp_buffer,
tmp_buffer,
tmp_buffer,
n_ws,
0,
sizeof(int) * 8,
stream);
d_temp_storage.resize(d_temp_storage_size, stream);
}

Expand Down
4 changes: 3 additions & 1 deletion python/cuml/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@

from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.input_utils import input_to_host_array
from cuml.internals.input_utils import input_to_host_array_with_sparse_support

from cuml.internals.memory_utils import rmm_cupy_ary
from cuml.internals.memory_utils import set_global_output_type
Expand Down Expand Up @@ -59,6 +60,7 @@
"has_scipy",
"input_to_cuml_array",
"input_to_host_array",
"input_to_host_array_with_sparse_support",
"rmm_cupy_ary",
"set_global_output_type",
"using_device_type",
Expand Down
3 changes: 2 additions & 1 deletion python/cuml/common/doc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@
" Ignored when return_sparse=False.\n"
" If True, values in the inverse transform below this parameter\n"
" are clipped to 0.",
None: "{name} : None\n"
" Ignored. This parameter exists for compatibility only.",
}

_parameter_possible_values = [
Expand Down Expand Up @@ -222,7 +224,6 @@ def deco(func):
if (
"X" in params or "y" in params or parameters
) and not skip_parameters_heading:

func.__doc__ += "\nParameters\n----------\n"

# Check if we want to prepend the parameters
Expand Down
3 changes: 2 additions & 1 deletion python/cuml/dask/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
np = cpu_only_import("numpy")


dask_cudf = gpu_only_import("dask_cudf")
dcDataFrame = gpu_only_import_from("dask_cudf.core", "DataFrame")


Expand Down Expand Up @@ -343,7 +344,7 @@ def _run_parallel_func(
if output_futures:
return self.client.compute(preds)
else:
output = dask.dataframe.from_delayed(preds)
output = dask_cudf.from_delayed(preds)
return output if delayed else output.persist()
else:
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion python/cuml/dask/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.
#

from cuml.dask.preprocessing.encoders import OneHotEncoder, OrdinalEncoder
from cuml.dask.preprocessing.label import LabelBinarizer
from cuml.dask.preprocessing.encoders import OneHotEncoder
from cuml.dask.preprocessing.LabelEncoder import LabelEncoder

__all__ = [
"LabelBinarizer",
"OneHotEncoder",
"OrdinalEncoder",
"LabelEncoder",
]
173 changes: 138 additions & 35 deletions python/cuml/dask/preprocessing/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dask_cudf.core import Series as daskSeries
from collections.abc import Sequence

from cuml.common import with_cupy_rmm
from cuml.dask.common.base import (
BaseEstimator,
DelayedInverseTransformMixin,
DelayedTransformMixin,
)
from cuml.internals.safe_imports import gpu_only_import_from, gpu_only_import
from dask_cudf.core import Series as daskSeries
from toolz import first

from cuml.dask.common.base import BaseEstimator
from cuml.dask.common.base import DelayedTransformMixin
from cuml.dask.common.base import DelayedInverseTransformMixin
dask_cudf = gpu_only_import("dask_cudf")
dcDataFrame = gpu_only_import_from("dask_cudf.core", "DataFrame")

from toolz import first

from collections.abc import Sequence
from cuml.internals.safe_imports import gpu_only_import_from
class DelayedFitTransformMixin:
def fit_transform(self, X, delayed=True):
"""Fit the encoder to X, then transform X. Equivalent to fit(X).transform(X).
dcDataFrame = gpu_only_import_from("dask_cudf.core", "DataFrame")
Parameters
----------
X : Dask cuDF DataFrame or CuPy backed Dask Array
The data to encode.
delayed : bool (default = True)
Whether to execute as a delayed task or eager.
Returns
-------
out : Dask cuDF DataFrame or CuPy backed Dask Array
Distributed object containing the transformed data
"""
return self.fit(X).transform(X, delayed=delayed)


class OneHotEncoder(
BaseEstimator, DelayedTransformMixin, DelayedInverseTransformMixin
BaseEstimator,
DelayedTransformMixin,
DelayedInverseTransformMixin,
DelayedFitTransformMixin,
):
"""
Encode categorical features as a one-hot numeric array.
Expand Down Expand Up @@ -83,13 +106,9 @@ class OneHotEncoder(
will be denoted as None.
"""

def __init__(self, *, client=None, verbose=False, **kwargs):
super().__init__(client=client, verbose=verbose, **kwargs)

@with_cupy_rmm
def fit(self, X):
"""
Fit a multi-node multi-gpu OneHotEncoder to X.
"""Fit a multi-node multi-gpu OneHotEncoder to X.
Parameters
----------
Expand All @@ -111,10 +130,9 @@ def fit(self, X):

return self

def fit_transform(self, X, delayed=True):
"""
Fit OneHotEncoder to X, then transform X.
Equivalent to fit(X).transform(X).
@with_cupy_rmm
def transform(self, X, delayed=True):
"""Transform X using one-hot encoding.
Parameters
----------
Expand All @@ -126,52 +144,137 @@ def fit_transform(self, X, delayed=True):
Returns
-------
out : Dask cuDF DataFrame or CuPy backed Dask Array
Distributed object containing the transformed data
Distributed object containing the transformed input.
"""
return self.fit(X).transform(X, delayed=delayed)
return self._transform(
X,
n_dims=2,
delayed=delayed,
output_dtype=self._get_internal_model().dtype,
output_collection_type="cupy",
)

@with_cupy_rmm
def transform(self, X, delayed=True):
"""
Transform X using one-hot encoding.
def inverse_transform(self, X, delayed=True):
"""Convert the data back to the original representation. In case unknown
categories are encountered (all zeros in the one-hot encoding), ``None`` is used
to represent this category.
Parameters
----------
X : Dask cuDF DataFrame or CuPy backed Dask Array
The data to encode.
X : CuPy backed Dask Array, shape [n_samples, n_encoded_features]
The transformed data.
delayed : bool (default = True)
Whether to execute as a delayed task or eager.
Returns
-------
out : Dask cuDF DataFrame or CuPy backed Dask Array
Distributed object containing the transformed input.
X_tr : Dask cuDF DataFrame or CuPy backed Dask Array
Distributed object containing the inverse transformed array.
"""
dtype = self._get_internal_model().dtype
return self._inverse_transform(
X,
n_dims=2,
delayed=delayed,
output_dtype=dtype,
output_collection_type=self.datatype,
)


class OrdinalEncoder(
BaseEstimator,
DelayedTransformMixin,
DelayedInverseTransformMixin,
DelayedFitTransformMixin,
):
"""Encode categorical features as an integer array.
The input to this transformer should be an :py:class:`dask_cudf.DataFrame` or a
:py:class:`dask.array.Array` backed by cupy, denoting the unique values taken on by
categorical (discrete) features. The features are converted to ordinal
integers. This results in a single column of integers (0 to n_categories - 1) per
feature.
Parameters
----------
categories : :py:class:`cupy.ndarray` or :py:class`cudf.DataFrameq, default='auto'
Categories (unique values) per feature. All categories are expected to
fit on one GPU.
- 'auto' : Determine categories automatically from the training data.
- DataFrame/ndarray : ``categories[col]`` holds the categories expected
in the feature col.
handle_unknown : {'error', 'ignore'}, default='error'
Whether to raise an error or ignore if an unknown categorical feature is
present during transform (default is to raise). When this parameter is set
to 'ignore' and an unknown category is encountered during transform, the
resulting encoded value would be null when output type is cudf
dataframe.
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`. See
:ref:`verbosity-levels` for more info.
"""

@with_cupy_rmm
def fit(self, X):
"""Fit Ordinal to X.
Parameters
----------
X : :py:class:`dask_cudf.DataFrame` or a CuPy backed :py:class:`dask.array.Array`.
shape = (n_samples, n_features) The data to determine the categories of each
feature.
Returns
-------
self
"""
from cuml.preprocessing.ordinalencoder_mg import OrdinalEncoderMG

el = first(X) if isinstance(X, Sequence) else X
self.datatype = (
"cudf" if isinstance(el, (dcDataFrame, daskSeries)) else "cupy"
)

self._set_internal_model(OrdinalEncoderMG(**self.kwargs).fit(X))

return self

@with_cupy_rmm
def transform(self, X, delayed=True):
"""Transform X using ordinal encoding.
Parameters
----------
X : :py:class:`dask_cudf.DataFrame` or cupy backed dask array. The data to
encode.
Returns
-------
X_out :
Transformed input.
"""
return self._transform(
X,
n_dims=2,
delayed=delayed,
output_dtype=self._get_internal_model().dtype,
output_collection_type="cupy",
output_collection_type=self.datatype,
)

@with_cupy_rmm
def inverse_transform(self, X, delayed=True):
"""
Convert the data back to the original representation.
In case unknown categories are encountered (all zeros in the
one-hot encoding), ``None`` is used to represent this category.
"""Convert the data back to the original representation.
Parameters
----------
X : CuPy backed Dask Array, shape [n_samples, n_encoded_features]
The transformed data.
X : :py:class:`dask_cudf.DataFrame` or cupy backed dask array.
delayed : bool (default = True)
Whether to execute as a delayed task or eager.
Returns
-------
X_tr : Dask cuDF DataFrame or CuPy backed Dask Array
X_tr :
Distributed object containing the inverse transformed array.
"""
dtype = self._get_internal_model().dtype
Expand Down
13 changes: 9 additions & 4 deletions python/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ from cuml.internals.safe_imports import (
np = cpu_only_import('numpy')
nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)

from sklearn.utils import estimator_html_repr
try:
from sklearn.utils import estimator_html_repr
except ImportError:
estimator_html_repr = None


import cuml
import cuml.common
Expand Down Expand Up @@ -447,9 +451,10 @@ class Base(TagsMixin,

def _repr_mimebundle_(self, **kwargs):
"""Prepare representations used by jupyter kernels to display estimator"""
output = {"text/plain": repr(self)}
output["text/html"] = estimator_html_repr(self)
return output
if estimator_html_repr is not None:
output = {"text/plain": repr(self)}
output["text/html"] = estimator_html_repr(self)
return output

def set_nvtx_annotations(self):
for func_name in ['fit', 'transform', 'predict', 'fit_transform',
Expand Down
14 changes: 14 additions & 0 deletions python/cuml/internals/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,20 @@ def input_to_host_array(
return out_data._replace(array=out_data.array.to_output("numpy"))


def input_to_host_array_with_sparse_support(X):
_array_type, is_sparse = determine_array_type_full(X)
if is_sparse:
if _array_type == "cupy":
return SparseCumlArray(X).to_output(output_type="scipy")
elif _array_type == "cuml":
return X.to_output(output_type="scipy")
elif _array_type == "numpy":
return X
else:
raise ValueError(f"Unsupported sparse array type: {_array_type}.")
return input_to_host_array(X).array


def convert_dtype(X, to_dtype=np.float32, legacy=True, safe_dtype=True):
"""
Convert X to be of dtype `dtype`, raising a TypeError
Expand Down
Loading

0 comments on commit 33d30aa

Please sign in to comment.