-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove MetricProcessor code from brute_force::knn #1426
Remove MetricProcessor code from brute_force::knn #1426
Conversation
Stop using the MetricProcessor code to preprocess the inputs to the bfknn calls. Since the pairwise distance API supports both cosine and correlation distance, this wasn't required anymore - and it introduced NaN values to the input when passed a dataset with one of the rows being all zero.
search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); | ||
raft::linalg::rowNorm( | ||
index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); | ||
// cosine needs the l2norm, where as l2 distances needs the squared norm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you also need to do this for correlation
since it's a normalized cosine or does that distance "just work" in the pw dists?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be a good idea to do this for correlation distance too for performance reasons - though the PW distance api does 'just work', and this change already improves the times quite a bit.
times on 23.06 branch (on a github dataset)
In [10]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="cosine")
56.6 ms ± 6.48 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [11]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="correlation")
82.8 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
times on this branch:
In [3]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="cosine")
...:
26.9 ms ± 32.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [4]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="correlation")
49.5 ms ± 51 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cosine/correlation times are both substantially improved with this change - though it might be worth expanding the correlation distance in the same way as we're doing with cosine/l2 etc in a future PR.
Note that in the demo wednesdays presentation - I was seeing 6.5ms for the l2 metric
In [12]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="l2")
6.51 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
However, this is because its hitting the fused l2 code - which is substantially faster in this case. The good news is that most of this perf difference is in the select_k
call, which I'm trying to get sped up now w/ changes like #1430. I believe we can get the cosine etc code up to around ~9ms on this call total - and the perf here is suffering since we're using the faiss select which does poorly on a single row.
/merge |
/merge |
Stop using the MetricProcessor code to preprocess the inputs to the bfknn calls. Since the pairwise distance API supports both cosine and correlation distance, this wasn't required anymore - and it introduced NaN values to the input when passed a dataset with one of the rows being all zero. Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#1426
Stop using the MetricProcessor code to preprocess the inputs to the bfknn calls. Since the pairwise distance API supports both cosine and correlation distance, this wasn't required anymore - and it introduced NaN values to the input when passed a dataset with one of the rows being all zero.