
Commit

Merge branch 'branch-23.12' into 2312-dbg-dask-timeout
dantegd authored Nov 30, 2023
2 parents cc82a80 + a1d1fb6 commit da4d4ca
Showing 26 changed files with 1,247 additions and 233 deletions.
31 changes: 31 additions & 0 deletions cpp/include/cuml/linear_model/qn_mg.hpp
@@ -63,6 +63,37 @@ void qnFit(raft::handle_t& handle,
float* f,
int* num_iters);

/**
* @brief support sparse input (Compressed Sparse Row format) for MNMG logistic regression fit
* using quasi-Newton methods
* @param[in] handle: the internal cuml handle object
* @param[in] input_values: vector holding non-zero values of all partitions for that rank
* @param[in] input_cols: vector holding column indices of non-zero values of all partitions for
* that rank
* @param[in] input_row_ids: vector holding row pointers of non-zero values of all partitions for
* that rank
* @param[in] X_nnz: the number of non-zero values of that rank
* @param[in] input_desc: PartDescriptor object for the input
* @param[in] labels: labels data
* @param[out] coef: learned coefficients
* @param[in] pams: model parameters
* @param[in] n_classes: number of outputs (number of classes or `1` for regression)
* @param[out] f: host pointer holding the final objective value
* @param[out] num_iters: host pointer holding the actual number of iterations taken
*/
void qnFitSparse(raft::handle_t& handle,
std::vector<Matrix::Data<float>*>& input_values,
int* input_cols,
int* input_row_ids,
int X_nnz,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels,
float* coef,
const qn_params& pams,
int n_classes,
float* f,
int* num_iters);

}; // namespace opg
}; // namespace GLM
}; // namespace ML
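
For reference on the CSR layout that the new qnFitSparse entry point consumes (input_values, input_cols, input_row_ids, X_nnz), here is a minimal Python sketch, using scipy purely for illustration, of how one rank's dense partition maps onto the three CSR arrays plus the non-zero count. The variable names are hypothetical; the real MNMG path goes through the C++ wrapper in cpp/src/glm/qn_mg.cu below.

import numpy as np
from scipy.sparse import csr_matrix

# A small dense partition owned by one rank (2 rows, 3 features).
dense = np.array([[1.0, 0.0, 2.0],
                  [0.0, 3.0, 0.0]], dtype=np.float32)
X = csr_matrix(dense)

# These map onto the qnFitSparse arguments documented above.
values = X.data       # non-zero values             -> input_values: [1., 2., 3.]
cols = X.indices      # column index per non-zero   -> input_cols:   [0, 2, 1]
row_ptrs = X.indptr   # row pointers, len n_rows+1  -> input_row_ids: [0, 2, 3]
nnz = X.nnz           # number of non-zero values   -> X_nnz: 3
print(values, cols, row_ptrs, nnz)
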
21 changes: 14 additions & 7 deletions cpp/src/glm/qn/mg/glm_base_mg.cuh
@@ -16,6 +16,7 @@

#include <raft/core/comms.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/multiply.cuh>
#include <raft/util/cudart_utils.hpp>

@@ -112,34 +113,42 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
T* dev_scalar,
cudaStream_t stream)
{
raft::comms::comms_t const& communicator = raft::resource::get_comms(*(this->handle_p));
SimpleDenseMat<T> W(wFlat.data, this->C, this->dims);
SimpleDenseMat<T> G(gradFlat.data, this->C, this->dims);
SimpleVec<T> lossVal(dev_scalar, 1);

// Ensure the same coefficients on all GPUs
communicator.bcast(wFlat.data, this->C * this->dims, 0, stream);
communicator.sync_stream(stream);

// apply regularization
auto regularizer_obj = this->objective;
auto lossFunc = regularizer_obj->loss;
auto reg = regularizer_obj->reg;
G.fill(0, stream);
float reg_host = 0;
T reg_host = 0;
if (reg->l2_penalty != 0) {
reg->reg_grad(dev_scalar, G, W, lossFunc->fit_intercept, stream);
raft::update_host(&reg_host, dev_scalar, 1, stream);
// note: avoid syncing here because there's a sync before reg_host is used.
raft::resource::sync_stream(*(this->handle_p));
}

// apply linearFwd, getLossAndDz, linearBwd
ML::GLM::detail::linearFwd(
lossFunc->handle, *(this->Z), *(this->X), W); // linear part: forward pass

raft::comms::comms_t const& communicator = raft::resource::get_comms(*(this->handle_p));

lossFunc->getLossAndDZ(dev_scalar, *(this->Z), *(this->y), stream); // loss specific part

// normalize local loss before allreduce sum
T factor = 1.0 * (*this->y).len / this->n_samples;
raft::linalg::multiplyScalar(dev_scalar, dev_scalar, factor, 1, stream);

// Each GPU calculates reg_host independently and may get values that diverge slightly.
// Take the averaged reg_host across ranks to avoid the divergence.
T reg_factor = reg_host / this->n_ranks;
raft::linalg::addScalar(dev_scalar, dev_scalar, reg_factor, 1, stream);

communicator.allreduce(dev_scalar, dev_scalar, 1, raft::comms::op_t::SUM, stream);
communicator.sync_stream(stream);

@@ -154,11 +163,9 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
communicator.allreduce(G.data, G.data, this->C * this->dims, raft::comms::op_t::SUM, stream);
communicator.sync_stream(stream);

float loss_host;
T loss_host;
raft::update_host(&loss_host, dev_scalar, 1, stream);
raft::resource::sync_stream(*(this->handle_p));
loss_host += reg_host;
lossVal.fill(loss_host + reg_host, stream);

return loss_host;
}
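
To make the aggregation in the modified loss path concrete: each rank scales its local loss by its share of the global sample count and adds a 1/n_ranks share of its locally computed regularization term, so the subsequent allreduce(SUM) yields the global mean loss plus the rank-averaged regularization. A small numpy sketch of that arithmetic (an illustration with made-up per-rank numbers, not the CUDA code path):

import numpy as np

# Hypothetical values for a 3-rank job.
local_losses = np.array([0.70, 0.65, 0.72])      # mean loss over each rank's samples
local_counts = np.array([100, 80, 20])           # samples owned by each rank
reg_host = np.array([0.0500, 0.0500, 0.0501])    # L2 term, computed independently per rank
n_ranks = len(local_losses)
n_samples = local_counts.sum()

# What each rank holds in dev_scalar just before the allreduce.
per_rank = local_losses * (local_counts / n_samples) + reg_host / n_ranks

# allreduce(SUM): identical on every rank.
global_loss = per_rank.sum()

# Equivalent single-node value: sample-weighted mean loss + averaged regularization.
expected = np.average(local_losses, weights=local_counts) + reg_host.mean()
assert np.isclose(global_loss, expected)
print(global_loss)
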
2 changes: 1 addition & 1 deletion cpp/src/glm/qn/mg/qn_mg.cuh
@@ -101,7 +101,7 @@ inline void qn_fit_x_mg(const raft::handle_t& handle,

switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
ASSERT(C == 2, "qn_mg.cuh: logistic loss invalid C");
ASSERT(C > 0, "qn_mg.cuh: logistic loss invalid C");
ML::GLM::detail::LogisticLoss<T> loss(handle, D, pams.fit_intercept);
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
73 changes: 73 additions & 0 deletions cpp/src/glm/qn_mg.cu
@@ -29,6 +29,8 @@
#include <vector>
using namespace MLCommon;

#include <iostream>

namespace ML {
namespace GLM {
namespace opg {
@@ -172,6 +174,77 @@ void qnFit(raft::handle_t& handle,
handle, input_data, input_desc, labels, coef, pams, X_col_major, n_classes, f, num_iters);
}

template <typename T, typename I>
void qnFitSparse_impl(const raft::handle_t& handle,
const qn_params& pams,
T* X_values,
I* X_cols,
I* X_row_ids,
I X_nnz,
T* y,
size_t N,
size_t D,
size_t C,
T* w0,
T* f,
int* num_iters,
size_t n_samples,
int rank,
int n_ranks)
{
auto X_simple = SimpleSparseMat<T>(X_values, X_cols, X_row_ids, X_nnz, N, D);

ML::GLM::opg::qn_fit_x_mg(handle,
pams,
X_simple,
y,
C,
w0,
f,
num_iters,
n_samples,
rank,
n_ranks); // ignore sample_weight, svr_eps
return;
}

void qnFitSparse(raft::handle_t& handle,
std::vector<Matrix::Data<float>*>& input_values,
int* input_cols,
int* input_row_ids,
int X_nnz,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels,
float* coef,
const qn_params& pams,
int n_classes,
float* f,
int* num_iters)
{
RAFT_EXPECTS(input_values.size() == 1,
"qn_mg.cu currently does not accept more than one input matrix");

auto data_input_values = input_values[0];
auto data_y = labels[0];

qnFitSparse_impl<float, int>(handle,
pams,
data_input_values->ptr,
input_cols,
input_row_ids,
X_nnz,
data_y->ptr,
input_desc.totalElementsOwnedBy(input_desc.rank),
input_desc.N,
n_classes,
coef,
f,
num_iters,
input_desc.M,
input_desc.rank,
input_desc.uniqueRanks().size());
}

}; // namespace opg
}; // namespace GLM
}; // namespace ML
12 changes: 10 additions & 2 deletions cpp/src/svm/kernelcache.cuh
@@ -130,8 +130,16 @@ class BatchCache : public raft::cache::Cache<math_t> {
RAFT_CUDA_TRY(cudaMemsetAsync(tmp_buffer, 0, n_ws * 2 * sizeof(int), stream));

// Init cub buffers
cub::DeviceRadixSort::SortKeys(
NULL, d_temp_storage_size, tmp_buffer, tmp_buffer, n_ws, 0, sizeof(int) * 8, stream);
cub::DeviceRadixSort::SortPairs(NULL,
d_temp_storage_size,
tmp_buffer,
tmp_buffer,
tmp_buffer,
tmp_buffer,
n_ws,
0,
sizeof(int) * 8,
stream);
d_temp_storage.resize(d_temp_storage_size, stream);
}

@@ -1147,8 +1147,8 @@ class make_column_selector:
Examples
--------
>>> from cuml.preprocessing import StandardScaler, OneHotEncoder
>>> from cuml.preprocessing import make_column_transformer
>>> from cuml.preprocessing import make_column_selector
>>> from cuml.compose import make_column_transformer
>>> from cuml.compose import make_column_selector
>>> import cupy as cp
>>> import cudf # doctest: +SKIP
>>> X = cudf.DataFrame({'city': ['London', 'London', 'Paris', 'Sallisaw'],
4 changes: 3 additions & 1 deletion python/cuml/common/__init__.py
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@

from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.input_utils import input_to_host_array
from cuml.internals.input_utils import input_to_host_array_with_sparse_support

from cuml.internals.memory_utils import rmm_cupy_ary
from cuml.internals.memory_utils import set_global_output_type
@@ -59,6 +60,7 @@
"has_scipy",
"input_to_cuml_array",
"input_to_host_array",
"input_to_host_array_with_sparse_support",
"rmm_cupy_ary",
"set_global_output_type",
"using_device_type",
3 changes: 2 additions & 1 deletion python/cuml/common/doc_utils.py
@@ -94,6 +94,8 @@
" Ignored when return_sparse=False.\n"
" If True, values in the inverse transform below this parameter\n"
" are clipped to 0.",
None: "{name} : None\n"
" Ignored. This parameter exists for compatibility only.",
}

_parameter_possible_values = [
@@ -222,7 +224,6 @@ def deco(func):
if (
"X" in params or "y" in params or parameters
) and not skip_parameters_heading:

func.__doc__ += "\nParameters\n----------\n"

# Check if we want to prepend the parameters
3 changes: 2 additions & 1 deletion python/cuml/dask/common/base.py
@@ -36,6 +36,7 @@
np = cpu_only_import("numpy")


dask_cudf = gpu_only_import("dask_cudf")
dcDataFrame = gpu_only_import_from("dask_cudf.core", "DataFrame")


@@ -343,7 +344,7 @@ def _run_parallel_func(
if output_futures:
return self.client.compute(preds)
else:
output = dask.dataframe.from_delayed(preds)
output = dask_cudf.from_delayed(preds)
return output if delayed else output.persist()
else:
raise ValueError(
27 changes: 25 additions & 2 deletions python/cuml/dask/linear_model/logistic_regression.py
@@ -21,6 +21,7 @@
from raft_dask.common.comms import get_raft_comm_state
from dask.distributed import get_worker

from cuml.common.sparse_utils import is_sparse, has_scipy
from cuml.dask.common import parts_to_ranks
from cuml.dask.common.input_utils import DistributedDataHandler, concatenate
from raft_dask.common.comms import Comms
@@ -29,7 +30,9 @@
from cuml.internals.safe_imports import gpu_only_import

cp = gpu_only_import("cupy")
cupyx = gpu_only_import("cupyx")
np = cpu_only_import("numpy")
scipy = cpu_only_import("scipy")


class LogisticRegression(LinearRegression):
@@ -172,13 +175,33 @@ def _create_model(sessionId, datatype, **kwargs):

@staticmethod
def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
inp_X = concatenate([X for X, _ in data])
if is_sparse(data[0][0]) is False:
inp_X = concatenate([X for X, _ in data])

elif has_scipy() and scipy.sparse.isspmatrix(data[0][0]):
inp_X = scipy.sparse.vstack([X for X, _ in data])

elif cupyx.scipy.sparse.isspmatrix(data[0][0]):
inp_X = cupyx.scipy.sparse.vstack([X for X, _ in data])

else:
raise ValueError(
"input matrix must be dense, scipy sparse, or cupy sparse"
)

inp_y = concatenate([y for _, y in data])
n_ranks = max([p[0] for p in partsToSizes]) + 1
aggregated_partsToSizes = [[i, 0] for i in range(n_ranks)]
for p in partsToSizes:
aggregated_partsToSizes[p[0]][1] += p[1]

return f.fit(
ret_status = f.fit(
[(inp_X, inp_y)], n_rows, n_cols, aggregated_partsToSizes, rank
)

if len(f.classes_) == 1:
raise ValueError(
f"This solver needs samples of at least 2 classes in the data, but the data contains only one class: {f.classes_[0]}"
)

return ret_status
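
For intuition about the dispatch added to _func_fit above, here is a small CPU-only sketch that mirrors the scipy branch (scipy stands in for cupyx, and plain tuples stand in for the worker's (X, y) partitions); the names are illustrative, not part of the cuML API. Several sparse partitions assigned to one worker are stacked row-wise into a single matrix before the local fit:

import numpy as np
from scipy.sparse import csr_matrix, isspmatrix, vstack

# Two hypothetical (X, y) partitions held by one worker.
data = [
    (csr_matrix(np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)), np.array([0, 1])),
    (csr_matrix(np.array([[0.0, 3.0]], dtype=np.float32)), np.array([1])),
]

if isspmatrix(data[0][0]):
    inp_X = vstack([X for X, _ in data])          # row-wise stack of the sparse partitions
else:
    inp_X = np.concatenate([X for X, _ in data])  # dense fallback

inp_y = np.concatenate([y for _, y in data])
print(inp_X.shape, inp_X.nnz, inp_y)              # (3, 2) 3 [0 1 1]
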
3 changes: 2 additions & 1 deletion python/cuml/dask/preprocessing/__init__.py
@@ -13,12 +13,13 @@
# limitations under the License.
#

from cuml.dask.preprocessing.encoders import OneHotEncoder, OrdinalEncoder
from cuml.dask.preprocessing.label import LabelBinarizer
from cuml.dask.preprocessing.encoders import OneHotEncoder
from cuml.dask.preprocessing.LabelEncoder import LabelEncoder

__all__ = [
"LabelBinarizer",
"OneHotEncoder",
"OrdinalEncoder",
"LabelEncoder",
]