Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve multi-CTA algorithm #492

Open
wants to merge 19 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,31 @@ void add_node_core(
raft::resource::get_cuda_stream(handle));
raft::resource::sync_stream(handle);

// Check search results
constexpr int max_warnings = 3;
int num_warnings = 0;
for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) {
std::uint32_t invalid_edges = 0;
for (std::uint32_t i = 0; i < base_degree; i++) {
if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; }
}
if (invalid_edges > 0) {
if (num_warnings < max_warnings) {
RAFT_LOG_WARN(
"Invalid edges found in search results "
"(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)",
(uint64_t)vec_i,
(uint64_t)invalid_edges,
(uint64_t)degree,
(uint64_t)base_degree);
}
num_warnings += 1;
}
}
if (num_warnings > max_warnings) {
RAFT_LOG_WARN("The number of queries that contain invalid search results: %d", num_warnings);
}
anaruse marked this conversation as resolved.
Show resolved Hide resolved

// Step 2: rank-based reordering
#pragma omp parallel
{
Expand All @@ -134,9 +159,16 @@ void add_node_core(
for (std::uint32_t i = 0; i < base_degree; i++) {
std::uint32_t detourable_node_count = 0;
const auto a_id = host_neighbor_indices(vec_i, i);
if (a_id >= idx.size()) {
// If the node ID is not valid, the number of detours is increased
// to a value greater than the maximum, so that the edge to that
// node is not selected as much as possible.
detourable_node_count_list[i] = std::make_pair(a_id, base_degree + 1);
anaruse marked this conversation as resolved.
Show resolved Hide resolved
continue;
}
for (std::uint32_t j = 0; j < i; j++) {
const auto b0_id = host_neighbor_indices(vec_i, j);
assert(b0_id < idx.size());
if (b0_id >= idx.size()) { continue; }
for (std::uint32_t k = 0; k < degree; k++) {
const auto b1_id = updated_graph(b0_id, k);
if (a_id == b1_id) {
Expand All @@ -147,6 +179,7 @@ void add_node_core(
}
detourable_node_count_list[i] = std::make_pair(a_id, detourable_node_count);
}

std::sort(detourable_node_count_list.begin(),
detourable_node_count_list.end(),
[&](const std::pair<IdxT, std::size_t> a, const std::pair<IdxT, std::size_t> b) {
Expand All @@ -168,13 +201,18 @@ void add_node_core(
const auto target_new_node_id = old_size + batch.offset() + vec_i;
for (std::size_t i = 0; i < num_rev_edges; i++) {
const auto target_node_id = updated_graph(old_size + batch.offset() + vec_i, i);

if (target_node_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", target_node_id);
}
IdxT replace_id = new_size;
IdxT replace_id_j = 0;
std::size_t replace_num_incoming_edges = 0;
for (std::int32_t j = degree - 1; j >= static_cast<std::int32_t>(rev_edge_search_range);
j--) {
const auto neighbor_id = updated_graph(target_node_id, j);
const auto neighbor_id = updated_graph(target_node_id, j);
if (neighbor_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", neighbor_id);
}
const std::size_t num_incoming_edges = host_num_incoming_edges(neighbor_id);
if (num_incoming_edges > replace_num_incoming_edges) {
// Check duplication
Expand All @@ -193,10 +231,6 @@ void add_node_core(
replace_id_j = j;
}
}
if (replace_id >= new_size) {
std::fprintf(stderr, "Invalid rev edge index (%u)\n", replace_id);
return;
}
updated_graph(target_node_id, replace_id_j) = target_new_node_id;
rev_edges[i] = replace_id;
}
Expand All @@ -208,13 +242,15 @@ void add_node_core(
const auto rank_based_list_ptr =
updated_graph.data_handle() + (old_size + batch.offset() + vec_i) * degree;
const auto rev_edges_return_list_ptr = rev_edges.data();
while (num_add < degree) {
while ((num_add < degree) &&
((rank_base_i < degree) || (rev_edges_return_i < num_rev_edges))) {
const auto node_list_ptr =
interleave_switch == 0 ? rank_based_list_ptr : rev_edges_return_list_ptr;
auto& node_list_index = interleave_switch == 0 ? rank_base_i : rev_edges_return_i;
const auto max_node_list_index = interleave_switch == 0 ? degree : num_rev_edges;
for (; node_list_index < max_node_list_index; node_list_index++) {
const auto candidate = node_list_ptr[node_list_index];
if (candidate >= new_size) { continue; }
// Check duplication
bool dup = false;
for (std::uint32_t j = 0; j < num_add; j++) {
Expand All @@ -231,6 +267,12 @@ void add_node_core(
}
interleave_switch = 1 - interleave_switch;
}
if (num_add < degree) {
RAFT_FAIL("Number of edges is not enough (target_new_node_id:%u, num_add:%u, degree:%u)",
target_new_node_id,
num_add,
degree);
}
for (std::uint32_t i = 0; i < degree; i++) {
updated_graph(target_new_node_id, i) = temp[i];
}
Expand All @@ -246,7 +288,9 @@ void add_graph_nodes(
raft::host_matrix_view<IdxT, std::int64_t> updated_graph_view,
const cagra::extend_params& params)
{
assert(input_updated_dataset_view.extent(0) >= index.size());
if (input_updated_dataset_view.extent(0) < index.size()) {
RAFT_FAIL("Updated dataset must be not smaller than the previous index state.");
}

const std::size_t initial_dataset_size = index.size();
const std::size_t new_dataset_size = input_updated_dataset_view.extent(0);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res,
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DataT, IndexT, DistanceT, CagraSampleFilterT_s>> plan =
factory<DataT, IndexT, DistanceT, CagraSampleFilterT_s>::create(
res, params, dataset_desc, queries.extent(1), graph.extent(1), topk);
res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk);

plan->check(topk);

Expand Down
64 changes: 47 additions & 17 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
const IndexT* __restrict__ seed_ptr, // [num_seeds]
const uint32_t num_seeds,
IndexT* __restrict__ visited_hash_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hash_ptr,
const uint32_t traversed_hash_bitlen,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand Down Expand Up @@ -145,19 +147,29 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(

const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (best_index_team_local != raft::upper_bound<IndexT>() &&
hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
} else {
result_distances_ptr[i] = raft::upper_bound<DistanceT>();
result_indices_ptr[i] = raft::upper_bound<IndexT>();
if (best_index_team_local != raft::upper_bound<IndexT>()) {
if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
} else if ((traversed_hash_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) {
// Deactivate this entry as it has been already used by otehrs.
anaruse marked this conversation as resolved.
Show resolved Hide resolved
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
}
}
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
}
}
}

template <typename IndexT, typename DistanceT, typename DATASET_DESCRIPTOR_T>
template <typename IndexT,
typename DistanceT,
typename DATASET_DESCRIPTOR_T,
int STATIC_RESULT_POSITION = 1>
RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
IndexT* __restrict__ result_child_indices_ptr,
DistanceT* __restrict__ result_child_distances_ptr,
Expand All @@ -168,13 +180,17 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const uint32_t knn_k,
// hashmap
IndexT* __restrict__ visited_hashmap_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hashmap_ptr,
const uint32_t traversed_hash_bitlen,
const IndexT* __restrict__ parent_indices,
const IndexT* __restrict__ internal_topk_list,
const uint32_t search_width)
const uint32_t search_width,
int* __restrict__ result_position = nullptr,
const int max_result_position = 0)
{
constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask<IndexT>::value;
constexpr IndexT invalid_index = raft::upper_bound<IndexT>();
constexpr IndexT invalid_index = ~static_cast<IndexT>(0);

// Read child indices of parents from knn graph and check if the distance
// computaiton is necessary.
Expand All @@ -186,11 +202,22 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
child_id = invalid_index;
} else if ((traversed_hashmap_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) {
// Deactivate this entry as this has been already used by others.
child_id = invalid_index;
}
}
result_child_indices_ptr[i] = child_id;
if (STATIC_RESULT_POSITION) {
result_child_indices_ptr[i] = child_id;
} else if (child_id != invalid_index) {
int j = atomicSub(result_position, 1) - 1;
result_child_indices_ptr[j] = child_id;
}
}
__syncthreads();

Expand All @@ -201,9 +228,11 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const auto compute_distance = dataset_desc.compute_distance_impl;
const auto args = dataset_desc.args.load();
const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0;
const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0];
for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) {
const bool valid_i = i < num_k;
const auto child_id = valid_i ? result_child_indices_ptr[i] : invalid_index;
const auto j = i + ofst;
const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position);
const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index;

// We should be calling `dataset_desc.compute_distance(..)` here as follows:
// > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index);
Expand All @@ -213,9 +242,10 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
(child_id != invalid_index) ? compute_distance(args, child_id)
: (lead_lane ? raft::upper_bound<DistanceT>() : 0),
team_size_bits);
__syncwarp();

// Store the distance
if (valid_i && lead_lane) { result_child_distances_ptr[i] = child_dist; }
if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; }
}
}

Expand Down
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ class factory {
search_params const& params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk);
return dispatch_kernel(res, plan, dataset_desc);
}

Expand All @@ -56,15 +57,15 @@ class factory {
if (plan.algo == search_algo::SINGLE_CTA) {
return std::make_unique<
single_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::make_unique<
multi_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else {
return std::make_unique<
multi_kernel_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
}
}
};
Expand Down
Loading
Loading