fix: filter out null values when sampling for index training #3404
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #3404 +/- ##
==========================================
+ Coverage 78.81% 78.85% +0.04%
==========================================
Files 250 250
Lines 91306 91475 +169
Branches 91306 91475 +169
==========================================
+ Hits 71963 72135 +172
+ Misses 16390 16379 -11
- Partials 2953 2961 +8
@@ -2215,6 +2222,77 @@ mod tests {
.await;
}

#[rstest]
#[tokio::test]
async fn test_create_index_nulls(
I'm thinking we should add some tests that verify recall. Then we would know whether flat search handles nulls well.
It might be good to modify this test https://github.com/lancedb/lance/blob/main/rust/lance/src/index/vector/ivf/v2.rs so that half of the rows contain nulls.
I was thinking, is there a way to count rows that are present in the index? I assume if it’s null then we don’t write it to the index file, right?
I have updated the test so it asserts we can use search to get all the non-null vectors back. But I am not getting the results I expect. I could use your advice to know what the expected behavior of these indices should be when there are lots of null vectors.
It seems there is no way to count that now. It could be easy for the v3 index, by counting the number of rows in the storage file.
I could use your advice to know what the expected behavior of these indices should be when there are lots of null vectors.
@BubbleCal Could you help me make sense of the output of this test? https://github.com/lancedb/lance/actions/runs/12918160780/job/36026117407?pr=3404
I was expecting search to only return non-null rows, but it seems like we are getting some null vectors in the results.
3393029 to 9bbd79f
rust/lance/src/index/vector/utils.rs
Outdated
// Need to filter out null values
// Use a scan to collect row ids. Then sample from the row ids. Then do take.
let row_addrs = dataset
    .scan()
    .filter_expr(datafusion_expr::col(column).is_not_null())
    .with_row_address()
    .project::<&str>(&[])?
    .try_into_batch()
    .await?;
@westonpace How expensive do you think this query is? This is filtering for non-null vectors and getting the row ids. Do you think there are easy optimizations we could do? If so, I'd like to capture that in a ticket.
I'm probably just missing something but I don't see where we are sampling.
rust/lance/src/index/vector/pq.rs
Outdated
@@ -447,6 +447,7 @@ pub async fn build_pq_model(
"Finished loading training data in {:02} seconds",
start.elapsed().as_secs_f32()
);
debug_assert_eq!(training_data.logical_null_count(), 0);
Maybe just assert_eq? This shouldn't be a critical section. Better safe than sorry.
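For context on the suggestion above, a minimal sketch of the difference between the two macros. The function names are hypothetical and only illustrate the behavior; they are not code from this PR. `debug_assert_eq!` is compiled out unless debug assertions are enabled (so typically absent in release builds), while `assert_eq!` is always checked.

```rust
/// Always checked, in debug and release builds alike.
/// (Hypothetical helper, for illustration only.)
fn check_no_nulls_always(null_count: usize) {
    assert_eq!(null_count, 0, "training data must not contain nulls");
}

/// Only checked when debug assertions are enabled; in a typical release
/// build this is a no-op and a non-zero count passes silently.
/// (Hypothetical helper, for illustration only.)
fn check_no_nulls_debug_only(null_count: usize) {
    debug_assert_eq!(null_count, 0, "training data must not contain nulls");
}
```

Calling `check_no_nulls_debug_only(1)` panics in a debug build but returns normally in a release build, which is the case the reviewer wants to avoid.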
rust/lance/src/index/vector/utils.rs
Outdated
// Use a scan to collect row ids. Then sample from the row ids. Then do take.
let row_addrs = dataset
    .scan()
    .filter_expr(datafusion_expr::col(column).is_not_null())
What am I missing? At the moment there is no cheap way to scan if a column is/is not null. So this filter will load the entire column into memory? Why do scan->filter->take and not just scan->filter?
Sorry, forgot the sampling bit. It sounds like the best thing for now is scan->filter + reservoir sampling?
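A minimal sketch of what "filter + reservoir sampling" could look like, assuming the scan is modeled as a plain iterator of optional values. This is the textbook Algorithm R, not the code in this PR; the `XorShift64` generator and the function name are hypothetical stand-ins so the sketch needs no external crates.

```rust
/// Tiny deterministic PRNG (xorshift64*), a stand-in so the sketch needs no
/// crates. The seed must be non-zero.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.0 = x;
        x.wrapping_mul(0x2545_F491_4F6C_DD1D)
    }

    /// Roughly uniform integer in 0..bound (modulo bias is ignored here).
    fn below(&mut self, bound: usize) -> usize {
        (self.next_u64() % bound as u64) as usize
    }
}

/// Reservoir sampling (Algorithm R): keep at most `k` non-null items from a
/// stream of unknown length, so that every non-null item has the same chance
/// of ending up in the result. Nulls are dropped before they are counted.
fn reservoir_sample_non_null<T>(
    stream: impl Iterator<Item = Option<T>>,
    k: usize,
    rng: &mut XorShift64,
) -> Vec<T> {
    let mut reservoir = Vec::with_capacity(k);
    let mut seen = 0usize; // number of non-null items seen so far
    for item in stream.flatten() {
        seen += 1;
        if reservoir.len() < k {
            // Fill phase: the first k non-null items are always kept.
            reservoir.push(item);
        } else {
            // Replace phase: keep the new item with probability k / seen.
            let j = rng.below(seen);
            if j < k {
                reservoir[j] = item;
            }
        }
    }
    reservoir
}
```

The trade-off is the one raised in the thread: this needs a single pass and bounded memory, but it still reads every row of the column.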
let projection = dataset.schema().project(&[column])?;
let batch = dataset.sample(sample_size_hint, &projection).await?;
info!(
    "Sample training data: retrieved {} rows by sampling",
    batch.num_rows()
);
batch
} else if num_rows > sample_size_hint && is_nullable {
Hmm...shouldn't you be using sample_size_hint? In this branch there are way more rows than we need. E.g. to train a dataset with 1B rows we need 30K partitions and so sample_size_hint will be ~8M. It looks like you're going to read all 1B vectors. Also, I don't see any randomization.
Did you mean to shuffle row_addrs?
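The numbers in this comment can be reproduced with a little arithmetic. The sketch below assumes the partition count is roughly the square root of the row count and that training asks for 256 rows per partition. Both are assumptions made for illustration, and the names are hypothetical rather than taken from the PR.

```rust
/// Assumed number of training rows requested per partition (illustrative).
const SAMPLES_PER_PARTITION: usize = 256;

/// Assumed rule of thumb: about sqrt(num_rows) partitions.
fn suggested_partitions(num_rows: usize) -> usize {
    (num_rows as f64).sqrt() as usize
}

/// Rows needed for training under the assumptions above.
fn sample_size_hint(num_partitions: usize) -> usize {
    num_partitions * SAMPLES_PER_PARTITION
}
```

Under these assumptions, 1B rows gives about 31.6K partitions, and 30K partitions gives 7,680,000 training rows, i.e. the "~8M" quoted above, which is less than 1% of the dataset.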
FWIW, in the Python, we do this (lance/python/python/lance/sampler.py, line 137 in 58c5e27: for shard in shards:):
- Create a randomized take stream that will eventually take the entire dataset
- Pull from the take stream and filter out nulls in-memory until we have sample_size_hint rows.
- Stop pulling from the take stream
To speed up the "randomized take stream" we actually stream random "contiguous shards" that are sized to give us at least 2K take operations if there are no nulls (IIRC)
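The steps above can be sketched roughly as follows, with the column modeled as an in-memory slice of optional values. This is not the Python sampler or the Rust code in this PR; the function names, the `shard_size` parameter, and the splitmix64 generator are all hypothetical, chosen so the sketch runs without external crates.

```rust
use std::ops::Range;

/// splitmix64 step: a small stand-in PRNG so the sketch needs no crates.
fn next_rand(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Split `0..num_rows` into contiguous shards and shuffle their order
/// (Fisher-Yates). `shard_size` must be greater than zero.
fn shuffled_shards(num_rows: usize, shard_size: usize, seed: &mut u64) -> Vec<Range<usize>> {
    let mut shards: Vec<Range<usize>> = (0..num_rows)
        .step_by(shard_size)
        .map(|start| start..(start + shard_size).min(num_rows))
        .collect();
    for i in (1..shards.len()).rev() {
        let j = (next_rand(seed) % (i as u64 + 1)) as usize;
        shards.swap(i, j);
    }
    shards
}

/// Pull shards in random order, drop nulls in memory, and stop as soon as
/// `sample_size_hint` non-null rows have been collected.
fn sample_non_null(
    column: &[Option<f32>],
    sample_size_hint: usize,
    shard_size: usize,
    seed: &mut u64,
) -> Vec<f32> {
    let mut sample = Vec::with_capacity(sample_size_hint.min(column.len()));
    for shard in shuffled_shards(column.len(), shard_size, seed) {
        // One contiguous read per shard; nulls are filtered after the read.
        sample.extend(column[shard].iter().flatten().copied());
        if sample.len() >= sample_size_hint {
            sample.truncate(sample_size_hint);
            break; // stop pulling from the take stream
        }
    }
    sample
}
```

Unlike a full scan with a filter, this only reads as many shards as it needs, so the cost scales with sample_size_hint and the null ratio rather than with the dataset size.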
Thanks, random_ranges looks useful. Maybe we can simplify the Python impls at some point in the future.
We were not filtering out null values when sampling. Because we often call array.values() on Arrow arrays, which ignores the null bitmap, we were often silently treating the nulls as zeros (or possibly undefined values). The only thing that caught these nulls was an assertion. However, the residualization that occurs with L2 and Cosine often meant that these values were transformed and the null information was lost before the assertion, which is why this got past previous unit tests.

This PR adds more assertions validating that there aren't nulls, and makes sure the sampling code handles null vectors.
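To illustrate the failure mode described here, a minimal sketch using a simplified stand-in for a nullable array: a dense value buffer plus a validity mask. `NullableArray` and its methods are hypothetical and are not the Arrow API; the null slot is assumed to hold 0.0, although in general its contents are unspecified.

```rust
/// Simplified stand-in for a nullable Arrow array (not the real Arrow API).
struct NullableArray {
    values: Vec<f32>,    // dense buffer; null slots still hold some value
    validity: Vec<bool>, // true where the slot is non-null
}

impl NullableArray {
    fn from_options(items: &[Option<f32>]) -> Self {
        Self {
            // Assumption: null slots are stored as 0.0.
            values: items.iter().map(|v| v.unwrap_or(0.0)).collect(),
            validity: items.iter().map(Option::is_some).collect(),
        }
    }

    /// Raw buffer access: the validity mask is ignored entirely.
    fn values(&self) -> &[f32] {
        &self.values
    }

    fn null_count(&self) -> usize {
        self.validity.iter().filter(|ok| !**ok).count()
    }

    /// Null-aware access: only slots marked valid are yielded.
    fn iter_valid(&self) -> impl Iterator<Item = f32> + '_ {
        self.values
            .iter()
            .zip(&self.validity)
            .filter(|(_, ok)| **ok)
            .map(|(v, _)| *v)
    }
}

/// Mean of a stream of values, or None if the stream is empty.
fn mean(xs: impl Iterator<Item = f32>) -> Option<f32> {
    let (sum, n) = xs.fold((0.0f32, 0usize), |(s, n), x| (s + x, n + 1));
    if n == 0 { None } else { Some(sum / n as f32) }
}
```

For the input [Some(2.0), None, Some(4.0)], the mean over values() is 2.0 because the null is counted as a zero, while the mean over iter_valid() is 3.0. A centroid trained on the raw buffer is skewed in the same way.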
Closes #3402
Closes #3400