From e6346aba39fd7f0f912243898415d1acb9558de0 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 2 Jan 2025 16:28:59 +0800 Subject: [PATCH 1/5] feat: vector search with distance range Signed-off-by: BubbleCal --- java/core/lance-jni/src/utils.rs | 2 + rust/lance-index/src/vector.rs | 6 + rust/lance-index/src/vector/flat/index.rs | 18 ++- rust/lance/src/dataset/scanner.rs | 63 ++++++++- rust/lance/src/index/vector/fixture_test.rs | 2 + rust/lance/src/index/vector/ivf.rs | 2 + rust/lance/src/index/vector/ivf/v2.rs | 135 +++++++++++++++++++- rust/lance/src/index/vector/pq.rs | 54 ++++++-- 8 files changed, 262 insertions(+), 20 deletions(-) diff --git a/java/core/lance-jni/src/utils.rs b/java/core/lance-jni/src/utils.rs index 742bff1742..4a2d4ae529 100644 --- a/java/core/lance-jni/src/utils.rs +++ b/java/core/lance-jni/src/utils.rs @@ -118,6 +118,8 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result> column, key, k, + lower_bound: None, + upper_bound: None, nprobes, ef, refine_factor, diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index cff976dcd3..22418ef65c 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -66,6 +66,12 @@ pub struct Query { /// Top k results to return. pub k: usize, + /// The lower bound (inclusive) of the distance to be searched. + pub lower_bound: Option, + + /// The upper bound (exclusive) of the distance to be searched. + pub upper_bound: Option, + /// The number of probes to load and search. pub nprobes: usize, diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index 297bf115c0..f85ef30803 100644 --- a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -44,11 +44,17 @@ lazy_static::lazy_static! { } #[derive(Default)] -pub struct FlatQueryParams {} +pub struct FlatQueryParams { + lower_bound: Option, + upper_bound: Option, +} impl From<&Query> for FlatQueryParams { - fn from(_: &Query) -> Self { - Self {} + fn from(q: &Query) -> Self { + Self { + lower_bound: q.lower_bound, + upper_bound: q.upper_bound, + } } } @@ -72,7 +78,7 @@ impl IvfSubIndex for FlatIndex { &self, query: ArrayRef, k: usize, - _params: Self::QueryParams, + params: Self::QueryParams, storage: &impl VectorStore, prefilter: Arc, ) -> Result { @@ -88,6 +94,8 @@ impl IvfSubIndex for FlatIndex { dist: OrderedFloat(dist), }) .sorted_unstable() + .skip_while(|r| params.lower_bound.map_or(false, |lb| r.dist.0 < lb)) + .take_while(|r| params.upper_bound.map_or(true, |ub| r.dist.0 < ub)) .take(k) .map( |OrderedNode { @@ -105,6 +113,8 @@ impl IvfSubIndex for FlatIndex { dist: OrderedFloat(dist_calc.distance(id as u32)), }) .sorted_unstable() + .skip_while(|r| params.lower_bound.map_or(false, |lb| r.dist.0 < lb)) + .take_while(|r| params.upper_bound.map_or(true, |ub| r.dist.0 < ub)) .take(k) .map( |OrderedNode { diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 22ee289c97..4d43e2b38e 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -34,6 +34,7 @@ use datafusion::physical_plan::{ ExecutionPlan, SendableRecordBatchStream, }; use datafusion::scalar::ScalarValue; +use datafusion_expr::Operator; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::{Partitioning, PhysicalExpr}; use futures::future::BoxFuture; @@ -705,6 +706,8 @@ impl Scanner { column: column.to_string(), key: key.into(), k, + lower_bound: None, + upper_bound: None, nprobes: 1, ef: None, refine_factor: None, @@ -714,6 +717,19 @@ impl Scanner { Ok(self) } + /// Set the distance thresholds for the nearest neighbor search. + pub fn distance_range( + &mut self, + lower_bound: Option, + upper_bound: Option, + ) -> &mut Self { + if let Some(q) = self.nearest.as_mut() { + q.lower_bound = lower_bound; + q.upper_bound = upper_bound; + } + self + } + pub fn nprobs(&mut self, n: usize) -> &mut Self { if let Some(q) = self.nearest.as_mut() { q.nprobes = n; @@ -1994,16 +2010,59 @@ impl Scanner { q.metric_type, )?); + // filter out elements out of distance range + let lower_bound_expr = q + .lower_bound + .map(|v| { + let lower_bound = expressions::lit(v); + expressions::binary( + expressions::col(DIST_COL, flat_dist.schema().as_ref())?, + Operator::GtEq, + lower_bound, + flat_dist.schema().as_ref(), + ) + }) + .transpose()?; + let upper_bound_expr = q + .upper_bound + .map(|v| { + let upper_bound = expressions::lit(v); + expressions::binary( + expressions::col(DIST_COL, flat_dist.schema().as_ref())?, + Operator::Lt, + upper_bound, + flat_dist.schema().as_ref(), + ) + }) + .transpose()?; + let filter_expr = match (lower_bound_expr, upper_bound_expr) { + (Some(lower), Some(upper)) => Some(expressions::binary( + lower, + Operator::And, + upper, + flat_dist.schema().as_ref(), + )?), + (Some(lower), None) => Some(lower), + (None, Some(upper)) => Some(upper), + (None, None) => None, + }; + + let knn_plan: Arc = if let Some(filter_expr) = filter_expr { + Arc::new(FilterExec::try_new(filter_expr, flat_dist)?) + } else { + flat_dist + }; + // Use DataFusion's [SortExec] for Top-K search let sort = SortExec::new( vec![PhysicalSortExpr { - expr: expressions::col(DIST_COL, flat_dist.schema().as_ref())?, + expr: expressions::col(DIST_COL, knn_plan.schema().as_ref())?, options: SortOptions { descending: false, nulls_first: false, }, }], - flat_dist, + knn_plan, ) .with_fetch(Some(q.k)); diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index 274f0e4493..7d3342c623 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -233,6 +233,8 @@ mod test { column: "test".to_string(), key: Arc::new(Float32Array::from(query)), k: 1, + lower_bound: None, + upper_bound: None, nprobes: 1, ef: None, refine_factor: None, diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 19f4bb7e01..d9e5db629f 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -1978,6 +1978,8 @@ mod tests { column: Self::COLUMN.to_string(), key: Arc::new(row), k: 5, + lower_bound: None, + upper_bound: None, nprobes: 1, ef: None, refine_factor: None, diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index c6a567efb1..05de877da6 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -518,6 +518,7 @@ mod tests { use std::collections::HashSet; use std::{collections::HashMap, ops::Range, sync::Arc}; + use all_asserts::{assert_ge, assert_lt}; use arrow::datatypes::{UInt64Type, UInt8Type}; use arrow::{array::AsArray, datatypes::Float32Type}; use arrow_array::{ @@ -614,7 +615,7 @@ mod tests { async fn test_index(params: VectorIndexParams, nlist: usize, recall_requirement: f32) { match params.metric_type { DistanceType::Hamming => { - test_index_impl::(params, nlist, recall_requirement, 0..2).await; + test_index_impl::(params, nlist, recall_requirement, 0..255).await; } _ => { test_index_impl::(params, nlist, recall_requirement, 0.0..1.0).await; @@ -746,6 +747,11 @@ mod tests { }); } + #[tokio::test] + async fn test_flat_knn() { + test_distance_range(None, 4).await; + } + #[rstest] #[case(4, DistanceType::L2, 1.0)] #[case(4, DistanceType::Cosine, 1.0)] @@ -759,6 +765,7 @@ mod tests { ) { let params = VectorIndexParams::ivf_flat(nlist, distance_type); test_index(params.clone(), nlist, recall_requirement).await; + test_distance_range(Some(params.clone()), nlist).await; test_remap(params, nlist).await; } @@ -776,6 +783,7 @@ mod tests { let pq_params = PQBuildParams::default(); let params = VectorIndexParams::with_ivf_pq_params(distance_type, ivf_params, pq_params); test_index(params.clone(), nlist, recall_requirement).await; + test_distance_range(Some(params.clone()), nlist).await; test_remap(params, nlist).await; } @@ -795,6 +803,7 @@ mod tests { .version(crate::index::vector::IndexFileVersion::V3) .clone(); test_index(params.clone(), nlist, recall_requirement).await; + test_distance_range(Some(params.clone()), nlist).await; test_remap(params, nlist).await; } @@ -814,6 +823,7 @@ mod tests { .version(crate::index::vector::IndexFileVersion::V3) .clone(); test_index(params.clone(), nlist, recall_requirement).await; + test_distance_range(Some(params.clone()), nlist).await; test_remap(params, nlist).await; } @@ -989,4 +999,127 @@ mod tests { assert_eq!(index["sub_index"]["index_type"].as_str().unwrap(), "HNSW"); } } + + async fn test_distance_range(params: Option, nlist: usize) { + match params.as_ref().map_or(DistanceType::L2, |p| p.metric_type) { + DistanceType::Hamming => { + test_distance_range_impl::(params, nlist, 0..255).await; + } + _ => { + test_distance_range_impl::(params, nlist, 0.0..1.0).await; + } + } + } + + async fn test_distance_range_impl( + params: Option, + nlist: usize, + range: Range, + ) where + T::Native: SampleUniform, + { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let (mut dataset, vectors) = generate_test_dataset::(test_uri, range).await; + + let vector_column = "vector"; + let dist_type = params.as_ref().map_or(DistanceType::L2, |p| p.metric_type); + if let Some(params) = params { + dataset + .create_index(&[vector_column], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + } + + let query = vectors.value(0); + let k = 100; + let result = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), k) + .unwrap() + .nprobs(nlist) + .with_row_id() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), k); + let row_ids = result[ROW_ID].as_primitive::().values(); + let dists = result[DIST_COL].as_primitive::().values(); + + let part_idx = k / 2; + let part_dist = dists[part_idx]; + + let left_res = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), part_idx) + .unwrap() + .nprobs(nlist) + .with_row_id() + .distance_range(None, Some(part_dist)) + .try_into_batch() + .await + .unwrap(); + let right_res = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), k - part_idx) + .unwrap() + .nprobs(nlist) + .with_row_id() + .distance_range(Some(part_dist), None) + .try_into_batch() + .await + .unwrap(); + // don't verify the number of results and row ids for hamming distance, + // because there are many vectors with the same distance + if dist_type != DistanceType::Hamming { + assert_eq!(left_res.num_rows(), part_idx); + assert_eq!(right_res.num_rows(), k - part_idx); + let left_row_ids = left_res[ROW_ID].as_primitive::().values(); + let right_row_ids = right_res[ROW_ID].as_primitive::().values(); + row_ids.iter().enumerate().for_each(|(i, id)| { + if i < part_idx { + assert_eq!(left_row_ids[i], *id); + } else { + assert_eq!(right_row_ids[i - part_idx], *id); + } + }); + } + let left_dists = left_res[DIST_COL].as_primitive::().values(); + let right_dists = right_res[DIST_COL].as_primitive::().values(); + left_dists.iter().for_each(|d| { + assert!(d < &part_dist); + }); + right_dists.iter().for_each(|d| { + assert!(d >= &part_dist); + }); + + let exclude_last_res = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), k) + .unwrap() + .nprobs(nlist) + .with_row_id() + .distance_range(dists.first().copied(), dists.last().copied()) + .try_into_batch() + .await + .unwrap(); + if dist_type != DistanceType::Hamming { + assert_eq!(exclude_last_res.num_rows(), k - 1); + let res_row_ids = exclude_last_res[ROW_ID] + .as_primitive::() + .values(); + row_ids.iter().enumerate().for_each(|(i, id)| { + if i < k - 1 { + assert_eq!(res_row_ids[i], *id); + } + }); + } + let res_dists = exclude_last_res[DIST_COL] + .as_primitive::() + .values(); + res_dists.iter().for_each(|d| { + assert_ge!(*d, dists[0]); + assert_lt!(*d, dists[k - 1]); + }); + } } diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index dc2de4c91a..03f86a7fdc 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -5,25 +5,25 @@ use std::sync::Arc; use std::{any::Any, collections::HashMap}; use arrow::compute::concat; -use arrow_array::UInt32Array; use arrow_array::{ cast::{as_primitive_array, AsArray}, Array, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array, }; +use arrow_array::{Float32Array, UInt32Array}; use arrow_ord::sort::sort_to_indices; -use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; +use arrow_schema::DataType; use arrow_select::take::take; use async_trait::async_trait; use deepsize::DeepSizeOf; +use lance_core::utils::address::RowAddress; use lance_core::utils::tokio::spawn_cpu; use lance_core::ROW_ID; -use lance_core::{utils::address::RowAddress, ROW_ID_FIELD}; use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::storage::{transpose, ProductQuantizationStorage}; use lance_index::vector::quantizer::{Quantization, QuantizationType, Quantizer}; use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{ - vector::{pq::ProductQuantizer, Query, DIST_COL}, + vector::{pq::ProductQuantizer, Query}, Index, IndexType, }; use lance_io::{traits::Reader, utils::read_fixed_stride_array}; @@ -41,6 +41,7 @@ use lance_linalg::kernels::normalize_fsl; use super::VectorIndex; use crate::index::prefilter::PreFilter; use crate::index::vector::utils::maybe_sample_training_data; +use crate::io::exec::knn::KNN_INDEX_SCHEMA; use crate::{arrow::*, Dataset}; use crate::{Error, Result}; @@ -226,15 +227,42 @@ impl VectorIndex for PQIndex { debug_assert_eq!(distances.len(), row_ids.len()); let limit = query.k * query.refine_factor.unwrap_or(1) as usize; - let indices = sort_to_indices(&distances, None, Some(limit))?; - let distances = take(&distances, &indices, None)?; - let row_ids = take(row_ids.as_ref(), &indices, None)?; - - let schema = Arc::new(ArrowSchema::new(vec![ - ArrowField::new(DIST_COL, DataType::Float32, true), - ROW_ID_FIELD.clone(), - ])); - Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?) + if query.lower_bound.is_none() && query.upper_bound.is_none() { + let indices = sort_to_indices(&distances, None, Some(limit))?; + let distances = take(&distances, &indices, None)?; + let row_ids = take(row_ids.as_ref(), &indices, None)?; + Ok(RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![distances, row_ids], + )?) + } else { + let indices = sort_to_indices(&distances, None, Some(limit))?; + let mut dists = Vec::with_capacity(limit); + let mut ids = Vec::with_capacity(limit); + for idx in indices.values().iter() { + let dist = distances.value(*idx as usize); + let id = row_ids.value(*idx as usize); + if query.lower_bound.map_or(false, |lb| dist < lb) { + continue; + } + if query.upper_bound.map_or(false, |ub| dist >= ub) { + break; + } + + dists.push(dist); + ids.push(id); + + if dists.len() >= limit { + break; + } + } + let dists = Arc::new(Float32Array::from(dists)); + let ids = Arc::new(UInt64Array::from(ids)); + Ok(RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![dists, ids], + )?) + } }) .await } From 5029eecdfd6b4cc5fdd39d9cbae61fc029c7be2b Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 2 Jan 2025 17:11:22 +0800 Subject: [PATCH 2/5] fix Signed-off-by: BubbleCal --- rust/lance/src/index/vector/ivf/v2.rs | 17 ++++++++--------- rust/lance/src/index/vector/pq.rs | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 05de877da6..c1c19bd6e4 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -772,7 +772,7 @@ mod tests { #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] #[tokio::test] async fn test_build_ivf_pq( #[case] nlist: usize, @@ -790,7 +790,7 @@ mod tests { #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] #[tokio::test] async fn test_build_ivf_pq_v3( #[case] nlist: usize, @@ -808,8 +808,8 @@ mod tests { } #[rstest] - #[case(4, DistanceType::L2, 0.9)] - #[case(4, DistanceType::Cosine, 0.9)] + #[case(4, DistanceType::L2, 0.85)] + #[case(4, DistanceType::Cosine, 0.85)] #[case(4, DistanceType::Dot, 0.8)] #[tokio::test] async fn test_build_ivf_pq_4bit( @@ -823,14 +823,13 @@ mod tests { .version(crate::index::vector::IndexFileVersion::V3) .clone(); test_index(params.clone(), nlist, recall_requirement).await; - test_distance_range(Some(params.clone()), nlist).await; test_remap(params, nlist).await; } #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] #[tokio::test] async fn test_create_ivf_hnsw_sq( #[case] nlist: usize, @@ -852,7 +851,7 @@ mod tests { #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] #[tokio::test] async fn test_create_ivf_hnsw_pq( #[case] nlist: usize, @@ -1032,7 +1031,7 @@ mod tests { } let query = vectors.value(0); - let k = 100; + let k = 10; let result = dataset .scan() .nearest(vector_column, query.as_primitive::(), k) @@ -1080,7 +1079,7 @@ mod tests { if i < part_idx { assert_eq!(left_row_ids[i], *id); } else { - assert_eq!(right_row_ids[i - part_idx], *id); + assert_eq!(right_row_ids[i - part_idx], *id, "{:?}", right_row_ids); } }); } diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 03f86a7fdc..3aa7568b20 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -236,7 +236,7 @@ impl VectorIndex for PQIndex { vec![distances, row_ids], )?) } else { - let indices = sort_to_indices(&distances, None, Some(limit))?; + let indices = sort_to_indices(&distances, None, None)?; let mut dists = Vec::with_capacity(limit); let mut ids = Vec::with_capacity(limit); for idx in indices.values().iter() { From cfbf1532a1c9cefb8a0b5c27daa6381301ec190f Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 2 Jan 2025 18:12:30 +0800 Subject: [PATCH 3/5] low recall requirement for 4bit Signed-off-by: BubbleCal --- rust/lance/src/index/vector/ivf/v2.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index c1c19bd6e4..9da1acb833 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -871,8 +871,8 @@ mod tests { } #[rstest] - #[case(4, DistanceType::L2, 0.9)] - #[case(4, DistanceType::Cosine, 0.9)] + #[case(4, DistanceType::L2, 0.85)] + #[case(4, DistanceType::Cosine, 0.85)] #[case(4, DistanceType::Dot, 0.8)] #[tokio::test] async fn test_create_ivf_hnsw_pq_4bit( From e71d61f1bb9c82aff765b5fd86d4ab06fdf45596 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 3 Jan 2025 13:06:37 +0800 Subject: [PATCH 4/5] optimize by binary search Signed-off-by: BubbleCal --- rust/lance-index/src/vector/flat/index.rs | 72 +++++++++++------------ 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index f85ef30803..cd5cb9abf5 100644 --- a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -11,7 +11,6 @@ use arrow::array::AsArray; use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use deepsize::DeepSizeOf; -use itertools::Itertools; use lance_core::{Error, Result, ROW_ID_FIELD}; use lance_file::reader::FileReader; use lance_linalg::distance::DistanceType; @@ -84,48 +83,45 @@ impl IvfSubIndex for FlatIndex { ) -> Result { let dist_calc = storage.dist_calculator(query); - let (row_ids, dists): (Vec, Vec) = match prefilter.is_empty() { - true => dist_calc - .distance_all() - .into_iter() - .zip(0..storage.len() as u32) - .map(|(dist, id)| OrderedNode { - id, - dist: OrderedFloat(dist), - }) - .sorted_unstable() - .skip_while(|r| params.lower_bound.map_or(false, |lb| r.dist.0 < lb)) - .take_while(|r| params.upper_bound.map_or(true, |ub| r.dist.0 < ub)) - .take(k) - .map( - |OrderedNode { - id, - dist: OrderedFloat(dist), - }| (storage.row_id(id), dist), - ) - .unzip(), + let mut res = match prefilter.is_empty() { + true => Vec::from_iter( + dist_calc + .distance_all() + .into_iter() + .zip(0..storage.len() as u32) + .map(|(dist, id)| OrderedNode { + id, + dist: OrderedFloat(dist), + }), + ), false => { let row_id_mask = prefilter.mask(); - (0..storage.len()) - .filter(|&id| row_id_mask.selected(storage.row_id(id as u32))) - .map(|id| OrderedNode { - id: id as u32, - dist: OrderedFloat(dist_calc.distance(id as u32)), - }) - .sorted_unstable() - .skip_while(|r| params.lower_bound.map_or(false, |lb| r.dist.0 < lb)) - .take_while(|r| params.upper_bound.map_or(true, |ub| r.dist.0 < ub)) - .take(k) - .map( - |OrderedNode { - id, - dist: OrderedFloat(dist), - }| (storage.row_id(id), dist), - ) - .unzip() + Vec::from_iter( + (0..storage.len()) + .filter(|&id| row_id_mask.selected(storage.row_id(id as u32))) + .map(|id| OrderedNode { + id: id as u32, + dist: OrderedFloat(dist_calc.distance(id as u32)), + }), + ) } }; + res.sort_unstable(); + + let filtered = if params.lower_bound.is_some() || params.upper_bound.is_some() { + let lower_bound = params.lower_bound.unwrap_or(f32::MIN); + let upper_bound = params.upper_bound.unwrap_or(f32::MAX); + let low_idx = res.partition_point(|r| r.dist.0 < lower_bound); + let high_idx = res.partition_point(|r| r.dist.0 < upper_bound); + res[low_idx..high_idx].iter() + } else { + res.iter() + }; + let (row_ids, dists): (Vec<_>, Vec<_>) = filtered + .take(k) + .map(|r| (storage.row_id(r.id), r.dist.0)) + .unzip(); let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists)); Ok(RecordBatch::try_new( From 7d3ad66156fe4e8a825d18aaa35c0e497ea25f84 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 3 Jan 2025 13:47:53 +0800 Subject: [PATCH 5/5] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/flat/index.rs | 54 +++++++++++++---------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index cd5cb9abf5..af89b90270 100644 --- a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -83,42 +83,48 @@ impl IvfSubIndex for FlatIndex { ) -> Result { let dist_calc = storage.dist_calculator(query); - let mut res = match prefilter.is_empty() { - true => Vec::from_iter( - dist_calc + let mut res: Vec<_> = match prefilter.is_empty() { + true => { + let iter = dist_calc .distance_all() .into_iter() .zip(0..storage.len() as u32) .map(|(dist, id)| OrderedNode { id, dist: OrderedFloat(dist), - }), - ), + }); + + if params.lower_bound.is_some() || params.upper_bound.is_some() { + let lower_bound = params.lower_bound.unwrap_or(f32::MIN); + let upper_bound = params.upper_bound.unwrap_or(f32::MAX); + iter.filter(|r| lower_bound <= r.dist.0 && r.dist.0 < upper_bound) + .collect() + } else { + iter.collect() + } + } false => { let row_id_mask = prefilter.mask(); - Vec::from_iter( - (0..storage.len()) - .filter(|&id| row_id_mask.selected(storage.row_id(id as u32))) - .map(|id| OrderedNode { - id: id as u32, - dist: OrderedFloat(dist_calc.distance(id as u32)), - }), - ) + let iter = (0..storage.len()) + .filter(|&id| row_id_mask.selected(storage.row_id(id as u32))) + .map(|id| OrderedNode { + id: id as u32, + dist: OrderedFloat(dist_calc.distance(id as u32)), + }); + if params.lower_bound.is_some() || params.upper_bound.is_some() { + let lower_bound = params.lower_bound.unwrap_or(f32::MIN); + let upper_bound = params.upper_bound.unwrap_or(f32::MAX); + iter.filter(|r| lower_bound <= r.dist.0 && r.dist.0 < upper_bound) + .collect() + } else { + iter.collect() + } } }; res.sort_unstable(); - let filtered = if params.lower_bound.is_some() || params.upper_bound.is_some() { - let lower_bound = params.lower_bound.unwrap_or(f32::MIN); - let upper_bound = params.upper_bound.unwrap_or(f32::MAX); - let low_idx = res.partition_point(|r| r.dist.0 < lower_bound); - let high_idx = res.partition_point(|r| r.dist.0 < upper_bound); - res[low_idx..high_idx].iter() - } else { - res.iter() - }; - - let (row_ids, dists): (Vec<_>, Vec<_>) = filtered + let (row_ids, dists): (Vec<_>, Vec<_>) = res + .into_iter() .take(k) .map(|r| (storage.row_id(r.id), r.dist.0)) .unzip();