Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some cleanup of k-means internals #953

Merged
merged 29 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c14602d
Adding optional handle to each public API function (along with example)
cjnolet Oct 24, 2022
1b9f8c8
Fixing style
cjnolet Oct 24, 2022
25cdb3a
Updating examples
cjnolet Oct 24, 2022
e0cb66d
Removing accidentally checked in files
cjnolet Oct 24, 2022
a6780b4
Fixing code blocks
cjnolet Oct 24, 2022
7015cca
Cleaning up kmeans internals (removing explicit const for mdspan, using
cjnolet Oct 25, 2022
e84ae26
Merge branch 'branch-22.12' into imp-2212-kmeans_cleanup
cjnolet Oct 25, 2022
ef230d0
Removing commented out code
cjnolet Oct 25, 2022
af0910e
Adding kmeans module to pylibraft for `update_centroids` function.
cjnolet Oct 25, 2022
6a2a8e7
compute_new_centroids test is returning results. Still need to validate
cjnolet Oct 25, 2022
223a4a4
Fixing sample weight
cjnolet Oct 25, 2022
68d7fb2
Allowing most of the outputs to be optional
cjnolet Oct 26, 2022
af401ae
Fixing python astyle
cjnolet Oct 26, 2022
8de2127
Style checks
cjnolet Oct 26, 2022
27d51fa
FIxing style
cjnolet Oct 26, 2022
6c8b3b4
Adding assertion to pytest for naive solutoin
cjnolet Oct 26, 2022
ec6a18b
Forcing batch size
cjnolet Oct 26, 2022
4a61ec2
Fixing style
cjnolet Oct 26, 2022
d5b8c49
Typo
cjnolet Oct 26, 2022
0cfd7d9
Removing argmin computation from `compute_new-centroids`
cjnolet Oct 26, 2022
e0ed9c0
Adding some vlaidation
cjnolet Oct 26, 2022
c717d61
Fixing style
cjnolet Oct 26, 2022
d1173ca
Fixing doc issue
cjnolet Oct 27, 2022
4ad7294
Removing record.txt from raft-dask
cjnolet Oct 27, 2022
5dd3745
Updating docs
cjnolet Oct 27, 2022
2945660
Merge remote-tracking branch 'rapidsai/branch-22.12' into imp-2212-km…
cjnolet Oct 27, 2022
fa2c8ed
Style
cjnolet Oct 27, 2022
4d8c0f5
Fixing docs
cjnolet Oct 27, 2022
e98bbdf
Review feedback
cjnolet Oct 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ build/
build_prims/
dist/
python/**/**/*.cpp
python/raft/record.txt
python/raft-dask/record.txt
python/pylibraft/record.txt
log
.ipynb_checkpoints
Expand Down
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ if(RAFT_COMPILE_DIST_LIBRARY)
add_library(raft_distance_lib
src/distance/pairwise_distance.cu
src/distance/fused_l2_min_arg.cu
src/distance/update_centroids_float.cu
src/distance/update_centroids_double.cu
src/distance/specializations/detail/canberra.cu
src/distance/specializations/detail/chebyshev.cu
src/distance/specializations/detail/correlation.cu
Expand Down
231 changes: 143 additions & 88 deletions cpp/include/raft/cluster/detail/kmeans.cuh

Large diffs are not rendered by default.

81 changes: 44 additions & 37 deletions cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void countLabels(const raft::handle_t& handle,

template <typename DataT, typename IndexT>
void checkWeight(const raft::handle_t& handle,
const raft::device_vector_view<DataT, IndexT>& weight,
raft::device_vector_view<DataT, IndexT> weight,
rmm::device_uvector<char>& workspace)
{
cudaStream_t stream = handle.get_stream();
Expand Down Expand Up @@ -166,24 +166,24 @@ void checkWeight(const raft::handle_t& handle,
}

template <typename IndexT>
IndexT getDataBatchSize(const KMeansParams& params, IndexT n_samples)
IndexT getDataBatchSize(int batch_samples, IndexT n_samples)
{
auto minVal = std::min(static_cast<IndexT>(params.batch_samples), n_samples);
auto minVal = std::min(static_cast<IndexT>(batch_samples), n_samples);
return (minVal == 0) ? n_samples : minVal;
}

template <typename IndexT>
IndexT getCentroidsBatchSize(const KMeansParams& params, IndexT n_local_clusters)
IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters)
{
auto minVal = std::min(static_cast<IndexT>(params.batch_centroids), n_local_clusters);
auto minVal = std::min(static_cast<IndexT>(batch_centroids), n_local_clusters);
return (minVal == 0) ? n_local_clusters : minVal;
}

template <typename DataT, typename ReductionOpT, typename IndexT = int>
void computeClusterCost(const raft::handle_t& handle,
const raft::device_vector_view<DataT, IndexT>& minClusterDistance,
raft::device_vector_view<DataT, IndexT> minClusterDistance,
rmm::device_uvector<char>& workspace,
const raft::device_scalar_view<DataT>& clusterCost,
raft::device_scalar_view<DataT> clusterCost,
ReductionOpT reduction_op)
{
cudaStream_t stream = handle.get_stream();
Expand Down Expand Up @@ -211,9 +211,9 @@ void computeClusterCost(const raft::handle_t& handle,

template <typename DataT, typename IndexT>
void sampleCentroids(const raft::handle_t& handle,
const raft::device_matrix_view<const DataT, IndexT>& X,
const raft::device_vector_view<DataT, IndexT>& minClusterDistance,
const raft::device_vector_view<uint8_t, IndexT>& isSampleCentroid,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_vector_view<DataT, IndexT> minClusterDistance,
raft::device_vector_view<uint8_t, IndexT> isSampleCentroid,
SamplingOp<DataT, IndexT>& select_op,
rmm::device_uvector<DataT>& inRankCp,
rmm::device_uvector<char>& workspace)
Expand Down Expand Up @@ -277,9 +277,9 @@ void sampleCentroids(const raft::handle_t& handle,
// result will be stored in 'pairwiseDistance[n x k]'
template <typename DataT, typename IndexT>
void pairwise_distance_kmeans(const raft::handle_t& handle,
const raft::device_matrix_view<const DataT, IndexT> X,
const raft::device_matrix_view<const DataT, IndexT> centroids,
const raft::device_matrix_view<DataT, IndexT> pairwiseDistance,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_matrix_view<DataT, IndexT> pairwiseDistance,
rmm::device_uvector<char>& workspace,
raft::distance::DistanceType metric)
{
Expand All @@ -305,8 +305,8 @@ void pairwise_distance_kmeans(const raft::handle_t& handle,
// in 'out' does not modify the input
template <typename DataT, typename IndexT>
void shuffleAndGather(const raft::handle_t& handle,
const raft::device_matrix_view<const DataT, IndexT>& in,
const raft::device_matrix_view<DataT, IndexT>& out,
raft::device_matrix_view<const DataT, IndexT> in,
raft::device_matrix_view<DataT, IndexT> out,
uint32_t n_samples_to_gather,
uint64_t seed)
{
Expand Down Expand Up @@ -340,24 +340,25 @@ void shuffleAndGather(const raft::handle_t& handle,
template <typename DataT, typename IndexT>
void minClusterAndDistanceCompute(
const raft::handle_t& handle,
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT> X,
const raft::device_matrix_view<const DataT, IndexT> centroids,
const raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
const raft::device_vector_view<DataT, IndexT> L2NormX,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
raft::device_vector_view<const DataT, IndexT> L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
raft::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace)
{
cudaStream_t stream = handle.get_stream();
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
auto n_clusters = centroids.extent(0);
auto metric = params.metric;
// todo(lsugy): change batch size computation when using fusedL2NN!
bool is_fused = metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded;
auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(params, n_samples);
auto centroidsBatchSize = getCentroidsBatchSize(params, n_clusters);
auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples);
auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters);

if (is_fused) {
L2NormBuf_OR_DistBuf.resize(n_clusters, stream);
Expand All @@ -369,6 +370,9 @@ void minClusterAndDistanceCompute(
true,
stream);
} else {
// TODO: Unless pool allocator is used, passing in a workspace for this
// isn't really increasing performance because this needs to do a re-allocation
// anyways. ref https://github.com/rapidsai/raft/issues/930
L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream);
}

Expand Down Expand Up @@ -403,7 +407,7 @@ void minClusterAndDistanceCompute(
minClusterAndDistance.data_handle() + dIdx, ns);

auto L2NormXView =
raft::make_device_vector_view<DataT, IndexT>(L2NormX.data_handle() + dIdx, ns);
raft::make_device_vector_view<const DataT, IndexT>(L2NormX.data_handle() + dIdx, ns);

if (is_fused) {
workspace.resize((sizeof(int)) * ns, stream);
Expand Down Expand Up @@ -471,24 +475,25 @@ void minClusterAndDistanceCompute(

template <typename DataT, typename IndexT>
void minClusterDistanceCompute(const raft::handle_t& handle,
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT>& X,
const raft::device_matrix_view<DataT, IndexT>& centroids,
const raft::device_vector_view<DataT, IndexT>& minClusterDistance,
const raft::device_vector_view<DataT, IndexT>& L2NormX,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<DataT, IndexT> centroids,
raft::device_vector_view<DataT, IndexT> minClusterDistance,
raft::device_vector_view<DataT, IndexT> L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
raft::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace)
{
cudaStream_t stream = handle.get_stream();
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
auto n_clusters = centroids.extent(0);
auto metric = params.metric;

bool is_fused = metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded;
auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(params, n_samples);
auto centroidsBatchSize = getCentroidsBatchSize(params, n_clusters);
auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples);
auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters);

if (is_fused) {
L2NormBuf_OR_DistBuf.resize(n_clusters, stream);
Expand Down Expand Up @@ -597,11 +602,11 @@ void minClusterDistanceCompute(const raft::handle_t& handle,
template <typename DataT, typename IndexT>
void countSamplesInCluster(const raft::handle_t& handle,
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT>& X,
const raft::device_vector_view<DataT, IndexT> L2NormX,
const raft::device_matrix_view<DataT, IndexT> centroids,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_vector_view<const DataT, IndexT> L2NormX,
raft::device_matrix_view<DataT, IndexT> centroids,
rmm::device_uvector<char>& workspace,
const raft::device_vector_view<DataT, IndexT> sampleCountInCluster)
raft::device_vector_view<DataT, IndexT> sampleCountInCluster)
{
cudaStream_t stream = handle.get_stream();
auto n_samples = X.extent(0);
Expand All @@ -623,12 +628,14 @@ void countSamplesInCluster(const raft::handle_t& handle,
// centroid) and 'value' is the distance between the sample 'X[i]' and the
// 'centroid[key]'
detail::minClusterAndDistanceCompute(handle,
params,
X,
(raft::device_matrix_view<const DataT, IndexT>)centroids,
minClusterAndDistance.view(),
L2NormX,
L2NormBuf_OR_DistBuf,
params.metric,
params.batch_samples,
params.batch_centroids,
workspace);

// Using TransformInputIteratorT to dereference an array of raft::KeyValuePair
Expand Down
Loading