Add public API and tests for hierarchical balanced k-means #1113

Merged
merged 53 commits into from
Jan 31, 2023
Changes from 44 commits
8930877
Add kmeans balanced public API and modify benchmark
Nyrio Dec 14, 2022
1b66ae0
kmeans_balanced unit test
Nyrio Dec 14, 2022
9efbd30
Add support for double
Nyrio Dec 15, 2022
876ed4e
Don't pass redundant stream argument when passing handle
Nyrio Dec 15, 2022
0f3ab9b
Replace uint32_t with IdxT where relevant
Nyrio Dec 15, 2022
289270d
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Dec 15, 2022
6a57b5e
Use core operators
Nyrio Dec 15, 2022
ef7ffbf
Add parameter structure for kmeans balanced, fix integer type issues,…
Nyrio Dec 16, 2022
c1c1faa
Testing and cleanup
Nyrio Dec 19, 2022
0bc3143
Add test cases for conversion + fix bug
Nyrio Dec 19, 2022
4622f1f
Document k-means balanced API
Nyrio Dec 19, 2022
f32b7f6
Use new API in IVF-Flat and IVF-PQ, remove old API
Nyrio Dec 20, 2022
77ab474
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Jan 4, 2023
b61d8a8
Casting to avoid template type inference failures
Nyrio Jan 4, 2023
c18b292
Snake-case kmeans_balanced_params
Nyrio Jan 9, 2023
78773a9
View fixes
Nyrio Jan 9, 2023
d493ded
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Jan 9, 2023
2f1fa20
Post-merge fix
Nyrio Jan 9, 2023
80d715b
Avoid applying mapping twice and assume identity when data and math t…
Nyrio Jan 12, 2023
f17fbd1
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Jan 12, 2023
7edb3b1
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Jan 17, 2023
0d04d4d
Add use-case examples to build_clusters and calc_centers_and_sizes
Nyrio Jan 17, 2023
a0c9421
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Jan 20, 2023
6d791d7
Documentation improvements + remove verbosity params (should use set_…
Nyrio Jan 23, 2023
89836d9
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Jan 23, 2023
82d1f14
Replace copy_selected with matrix::gather where possible
Nyrio Jan 23, 2023
c5aff02
Fix unsigned / signed integer comparisons
Nyrio Jan 23, 2023
8705e3a
Put build_clusters and calc_centers_and_sizes in a helpers namespace
Nyrio Jan 23, 2023
1a77a09
Change order of args in public API for consistency
Nyrio Jan 23, 2023
a6232a2
Enabling shallow copy of handle/device_resources
cjnolet Jan 24, 2023
32af7f0
Adding copy constructor to handle/device_resources that enables a dif…
cjnolet Jan 24, 2023
234f004
Merge remote-tracking branch 'cjnolet/imp-2302-resources_copy' into f…
Nyrio Jan 24, 2023
1c03a46
Don't expose memory resource in public API, use handle instead
Nyrio Jan 24, 2023
3073672
Add base class for shared k-means parameters + add explanation of the…
Nyrio Jan 24, 2023
d07bb66
De-duplicate mod_op and modulo_op
Nyrio Jan 24, 2023
d76877a
Replace modulo_op with mod_op (bis)
Nyrio Jan 24, 2023
30fd8f2
Fix changes lost in bad merge
Nyrio Jan 24, 2023
7a163fa
Doxygen fixes
Nyrio Jan 24, 2023
a25df21
Fix use of is_floating_point_v
Nyrio Jan 24, 2023
5166503
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Jan 25, 2023
7372f80
mesocluster_size_max_balanced should be IdxT, not uint32_t
Nyrio Jan 25, 2023
88eef79
For some reason my local clang-format seems to disagree with the pre-…
Nyrio Jan 25, 2023
4b41d41
Change edge case that breaks checks
Nyrio Jan 25, 2023
a829fef
Include minClusterAndDistance in batch size calculation
Nyrio Jan 25, 2023
7d399db
Merge remote-tracking branch 'origin/branch-23.02' into fea-kmeans-ba…
Nyrio Jan 26, 2023
2d78a50
Remove write_only_op helper as we can now use map_offset
Nyrio Jan 26, 2023
f262406
Replace handle_t with device_resources
Nyrio Jan 26, 2023
39e55eb
Fix missing includes
Nyrio Jan 26, 2023
4e6704c
Merge branch 'branch-23.02' into fea-kmeans-balanced-api
Nyrio Jan 27, 2023
02feb04
Merge branch 'branch-23.02' into fea-kmeans-balanced-api
cjnolet Jan 27, 2023
462fa66
Merge branch 'branch-23.02' into fea-kmeans-balanced-api
cjnolet Jan 27, 2023
6c73586
Add test with more rows to test memory footprint to some extent
Nyrio Jan 30, 2023
55fcb1f
Merge branch 'branch-23.02' into fea-kmeans-balanced-api
cjnolet Jan 31, 2023
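Several of the commits above replace `uint32_t` with the template index type `IdxT`, and the benchmark in this PR is registered with `int64_t` indices because some dataset sizes are too large for 32-bit index types. A minimal sketch of that constraint follows. The helper `flat_index_fits` is hypothetical and not part of RAFT; it only checks whether the element count of a row-major matrix is representable in a given index type.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper (not a RAFT function): true when rows * cols, the
// element count of a row-major matrix, is representable in IdxT.
// It divides instead of multiplying so the check itself cannot overflow.
template <typename IdxT>
constexpr bool flat_index_fits(std::uint64_t rows, std::uint64_t cols)
{
  return rows == 0 || cols == 0 ||
         rows <= static_cast<std::uint64_t>(std::numeric_limits<IdxT>::max()) / cols;
}

// 10,000,000 x 128 is one of the benchmark shapes; its element count
// (1,280,000,000) is still below the int32_t maximum of 2,147,483,647.
static_assert(flat_index_fits<std::int32_t>(10000000, 128), "fits in int32_t");

// 100,000,000 x 128 is an arbitrary larger shape chosen for illustration,
// not one taken from the PR; it needs a 64-bit index.
static_assert(!flat_index_fits<std::int32_t>(100000000, 128), "overflows int32_t");
static_assert(flat_index_fits<std::int64_t>(100000000, 128), "fits in int64_t");
```

Templating on the index type lets the caller choose the width per dataset, rather than hard-coding `uint32_t` in the implementation.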
28 changes: 10 additions & 18 deletions cpp/bench/cluster/kmeans_balanced.cu
@@ -15,20 +15,19 @@
  */

 #include <common/benchmark.hpp>
+#include <raft/cluster/kmeans_balanced.cuh>
 #include <raft/random/rng.cuh>
-#include <raft/spatial/knn/detail/ann_kmeans_balanced.cuh>

-#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED
-#include <raft/cluster/specializations.cuh>
+#if defined RAFT_DISTANCE_COMPILED
+#include <raft/distance/specializations.cuh>
 #endif

 namespace raft::bench::cluster {

 struct KMeansBalancedBenchParams {
   DatasetParams data;
-  uint32_t max_iter;
   uint32_t n_lists;
-  raft::distance::DistanceType metric;
+  raft::cluster::kmeans_balanced_params kb_params;
 };

 template <typename T, typename IndexT = int>
@@ -38,15 +37,10 @@ struct KMeansBalanced : public fixture {
   void run_benchmark(::benchmark::State& state) override
   {
     this->loop_on_state(state, [this]() {
-      raft::spatial::knn::detail::kmeans::build_hierarchical<T>(this->handle,
-                                                                this->params.max_iter,
-                                                                (uint32_t)this->params.data.cols,
-                                                                this->X.data_handle(),
-                                                                this->params.data.rows,
-                                                                this->centroids.data_handle(),
-                                                                this->params.n_lists,
-                                                                this->params.metric,
-                                                                this->handle.get_stream());
+      raft::device_matrix_view<const T, IndexT> X_view = this->X.view();
+      raft::device_matrix_view<T, IndexT> centroids_view = this->centroids.view();
+      raft::cluster::kmeans_balanced::fit(
+        this->handle, this->params.kb_params, X_view, centroids_view);
     });
   }

@@ -84,8 +78,8 @@ std::vector<KMeansBalancedBenchParams> getKMeansBalancedInputs()
   std::vector<KMeansBalancedBenchParams> out;
   KMeansBalancedBenchParams p;
   p.data.row_major = true;
-  p.max_iter = 20;
-  p.metric = raft::distance::DistanceType::L2Expanded;
+  p.kb_params.n_iters = 20;
+  p.kb_params.metric = raft::distance::DistanceType::L2Expanded;
   std::vector<std::pair<int, int>> row_cols = {
     {100000, 128}, {1000000, 128}, {10000000, 128},
     // The following dataset sizes are too large for most GPUs.
@@ -104,7 +98,5 @@ std::vector<KMeansBalancedBenchParams> getKMeansBalancedInputs()

 // Note: the datasets sizes are too large for 32-bit index types.
 RAFT_BENCH_REGISTER((KMeansBalanced<float, int64_t>), "", getKMeansBalancedInputs());
-RAFT_BENCH_REGISTER((KMeansBalanced<int8_t, int64_t>), "", getKMeansBalancedInputs());
-RAFT_BENCH_REGISTER((KMeansBalanced<uint8_t, int64_t>), "", getKMeansBalancedInputs());

 } // namespace raft::bench::cluster
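
The benchmark change above swaps a call that took raw pointers and loose arguments (`max_iter`, `metric`, a separate stream) for one that takes a parameter struct and two matrix views. The host-only sketch below mirrors that calling convention. Every name in the `sketch` namespace is a hypothetical stand-in, not RAFT's type, and the body runs plain Lloyd iterations rather than the hierarchical balanced algorithm this PR exposes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

namespace sketch {  // hypothetical stand-ins, not RAFT types

enum class distance_type { L2Expanded, InnerProduct };

// Groups the knobs that the old call passed as separate arguments.
struct balanced_params {
  std::uint32_t n_iters = 20;
  distance_type metric  = distance_type::L2Expanded;
};

// Minimal non-owning row-major 2-D view: pointer plus extents.
template <typename T, typename IdxT>
struct matrix_view {
  T* data;
  IdxT rows;
  IdxT cols;
  T& operator()(IdxT i, IdxT j) const { return data[i * cols + j]; }
};

// Plain Lloyd iterations on the host, seeded with evenly strided rows of X.
// Assumption: this only illustrates the (params, X, centroids) call shape.
template <typename T, typename IdxT>
void fit(const balanced_params& params,
         matrix_view<const T, IdxT> X,
         matrix_view<T, IdxT> centroids)
{
  const IdxT k = centroids.rows, d = X.cols;
  assert(k > 0 && centroids.cols == d);
  for (IdxT c = 0; c < k; ++c)
    for (IdxT j = 0; j < d; ++j) centroids(c, j) = X(c * (X.rows / k), j);

  std::vector<double> sums(static_cast<std::size_t>(k * d));
  std::vector<IdxT> counts(static_cast<std::size_t>(k));
  for (std::uint32_t it = 0; it < params.n_iters; ++it) {
    std::fill(sums.begin(), sums.end(), 0.0);
    std::fill(counts.begin(), counts.end(), IdxT{0});
    for (IdxT i = 0; i < X.rows; ++i) {
      // Assign row i to the centroid with the lowest score.
      IdxT best         = 0;
      double best_score = std::numeric_limits<double>::max();
      for (IdxT c = 0; c < k; ++c) {
        double score = 0.0;
        for (IdxT j = 0; j < d; ++j) {
          const double x = X(i, j), m = centroids(c, j);
          score += (params.metric == distance_type::L2Expanded) ? (x - m) * (x - m) : -x * m;
        }
        if (score < best_score) { best_score = score; best = c; }
      }
      ++counts[best];
      for (IdxT j = 0; j < d; ++j) sums[best * d + j] += X(i, j);
    }
    // Move each non-empty cluster's centroid to the mean of its rows.
    for (IdxT c = 0; c < k; ++c) {
      if (counts[c] == 0) continue;
      for (IdxT j = 0; j < d; ++j) centroids(c, j) = static_cast<T>(sums[c * d + j] / counts[c]);
    }
  }
}

}  // namespace sketch
```

A caller builds a `matrix_view<const float, std::int64_t>` over the input and a `matrix_view<float, std::int64_t>` over the output buffer, then calls `sketch::fit(params, X_view, centroids_view)`. Adding a new tuning option then means adding a field to the struct, not changing every call site.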