-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add index class for brute_force knn #1817
Conversation
This adds an index class to match the ANN methods. This allows us to precompute norms for the dataset in `brute_force::build` and then use them in `brute_force::search` - meaning we don't have to compute norms for the entire dataset on every query.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Just a few minor things.
*/ | ||
template <typename T, | ||
typename Accessor = | ||
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect defaulting this is going to cause some unexpected behaviors down the line. We should probably not default this. Though it is great to see that we can now support datasets on host!
raft::linalg::NormType::L2Norm, | ||
raft::linalg::Apply::ALONG_ROWS); | ||
} | ||
ret.update_norms(std::move(norms)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To reduce the amount of mutability, can we just accept these in the constructor as a std::optional?
Most users are going to be using raft::neighbors::brute_force::build() to construct the index instead of constructing it directly, so it doesn't really impose any additional burden on the user.
@@ -38,6 +39,21 @@ inline void knn_merge_parts( | |||
size_t n_samples, | |||
std::optional<raft::device_vector_view<idx_t, idx_t>> translations = std::nullopt) RAFT_EXPLICIT; | |||
|
|||
template <typename T, | |||
typename Accessor = | |||
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just let the tile get inferred? I don't think we want to default this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
/merge |
/merge |
This adds an index class to match the ANN methods. This allows us to precompute norms for the dataset in
brute_force::build
and then use them inbrute_force::search
- meaning we don't have to compute norms for the entire dataset on every query.