-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for iterating over batches in bfknn #1947
Conversation
This adds support for iterating over batches of nearest neighbors in the brute force knn. This lets you query for the nearest neighbors, and then filter down the results - and if you have filtered out too many results, get the next batch of nearest neighbors. The challenge here is balancing memory consumption versus efficieny: we could store the results of the full gemm in the distance calculation - but for large indices this discards the benefits of using the tiling strategy and risks running OOM. Instead we exponentially grow the number of neighbors being returned, and also cache both the query norms and index norms between calls.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have not done an in-depth line-by-line review of this, but I wanted to provide feedback on the overall API. Short version is I really like it! I was worried about storing the resources object at first, but the implementation both makes it clear why it is necessary and limits the opportunities for that to bite the user.
For downstream integrations, we're going to want to adopt some careful patterns for ensuring that we don't e.g. modify the index part way through batched searches, but this design will integrate very cleanly with the work we're already doing in this direction.
This one gets a big +1 from me, and I can't wait to see similar capabilities for all types of indexes. Brilliant work!
} | ||
|
||
/** a single batch of nearest neighbors in device memory */ | ||
class batch { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This batch
class could be re-used for batching IVF and CAGRA searches. Can it be implemented in a separate file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even though iterators are common in C++ collections APIs, this is a new design paradigm for RAFT and implementing brute-force first was a great idea to leave a small buffer for getting the API nice and tight. A great example of how we're implementing separate interfaces and impls in RAFT even though it's header-only is in the raft::comms
namespaces (we use PIMPL paradigm to do this so the raft::resources
only has to care that a raft::comms
instance has been set.
A couple of thoughts / intentional design decisions in RAFT
- Using complex stateful objects is fine, but we should still facade them behind factory functions so that we're providing a layer of indirection between the object construction and the user. This gives us the ability to modify how the objects get constructed in the future without impacting the user.
- Trying to avoid storing
raft::resources
as state on underlying objects. This happened as a byproduct of the way we're constructing the container policies, but it was only dont with the intention of delaying the construction itself and isn't relied upon for any other state on the object. We're trying to move away from this pattern when we can avoid it. - Exposing interfaces- by convention we use
*_types.hpp
to tell the user "these things are safe to expose on your own public APIs and they won't bring in additional dependencies. If RAFT is being compiled into a user's library, think of these as the sets of headers that are safe to expose in the users owninclude/
directory. This means we should be hiding implementations for things like actual iterator types behidndetail
as much as possible. It used to be okay to expose just cuda runtime things through*_types.hpp
but we're trying to move away from that as well to support the potential that we might end up hosting Grace or other CPU-only APIs in the future.
Most of my review comments are aesthetics, though, and I'm going to take a second pass over the implementation details but things look really good at a glance.
* @tparam IdxT type of the indices in the source dataset | ||
*/ | ||
template <typename T, typename IdxT = int64_t> | ||
class batch_k_query { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please define the interface in a file called raft/neighbors/neighbors_types.hpp
(for generalized / non-brute-force types) and/or raft/neighbors/brute_force_types.hpp
, define the implementation in raft::neighbors::detail
namespace, and then create a stateless factory function like raft::neighbors::brute_force::search_batch_k()
to create an instance of it.
metric == raft::distance::DistanceType::CosineExpanded) { | ||
query_norms = make_device_vector<T, int64_t>(res, query.extent(0)); | ||
|
||
if (metric == raft::distance::DistanceType::CosineExpanded) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do this conditional in a few places in the code- perhaps we should consolidate them into a function in detail
.
void slice_current_batch(int64_t offset, int64_t batch_size) | ||
{ | ||
auto num_queries = batches.indices_.extent(0); | ||
batch_size = std::min(batch_size, query->index.size() - offset); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the batch size changes as we iterate through k?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The batch sizes can change as we iterate through K - if we want to support the redis VecSimBatchIterator
interface, we will need this to handle the getNextResults
method https://github.com/RedisAI/VectorSimilarity/blob/22954489d9184c9eba55f477463439a3532ca04e/src/VecSim/batch_iterator.h#L40-L42
cpp/test/neighbors/tiled_knn.cu
Outdated
offset = 0; | ||
int64_t batch_size = k_; | ||
batch_k_query<T, int> query(handle_, idx, query_view, batch_size); | ||
for (auto it = query.begin(); it != query.end(); it.advance(batch_size)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the user is expected to invoke advance()
, they should be able to pass in the raft::resources
instance each time, right?
Can you add some documentation and example to |
@lowener, I agree that this API should be included in the vector search tutorial, even if mentioned in the "additional features" section, but in the interest of time I think we should do a separate PR and target 24.02 with that change. We're just too close to burndown at this point. |
/merge |
This adds support for iterating over batches of nearest neighbors in the brute force knn. This lets you query for the nearest neighbors, and then filter down the results - and if you have filtered out too many results, get the next batch of nearest neighbors.
The challenge here is balancing memory consumption versus efficieny: we could store the results of the full gemm in the distance calculation - but for large indices this discards the benefits of using the tiling strategy and risks running OOM. Instead we exponentially grow the number of neighbors being returned, and also cache both the query norms and index norms between calls.