From b5fd692fe392c6d923df3635f298e0a5864ee5dd Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 24 Apr 2024 16:34:12 +0800 Subject: [PATCH 1/3] perf: do HNSW search with threads of CPU runtime Signed-off-by: BubbleCal --- rust/lance/src/index/vector/hnsw.rs | 13 +++++++----- rust/lance/src/index/vector/ivf.rs | 31 +++++++++++++++++++++++++---- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/rust/lance/src/index/vector/hnsw.rs b/rust/lance/src/index/vector/hnsw.rs index eec4cf7040..410839745e 100644 --- a/rust/lance/src/index/vector/hnsw.rs +++ b/rust/lance/src/index/vector/hnsw.rs @@ -10,6 +10,7 @@ use std::{ use arrow_array::{Float32Array, RecordBatch, UInt64Array}; use async_trait::async_trait; +use lance_core::utils::tokio::spawn_cpu; use lance_core::{datatypes::Schema, Error, Result}; use lance_file::reader::FileReader; use lance_index::{ @@ -45,7 +46,7 @@ pub(crate) struct HNSWIndexOptions { #[derive(Clone)] pub(crate) struct HNSWIndex { - hnsw: HNSW, + hnsw: Arc, // TODO: move these into IVFIndex after the refactor is complete partition_storage: IvfQuantizationStorage, @@ -79,7 +80,7 @@ impl HNSWIndex { let ivf_store = IvfQuantizationStorage::open(aux_reader).await?; Ok(Self { - hnsw, + hnsw: Arc::new(hnsw), partition_storage: ivf_store, partition_metadata, options, @@ -167,7 +168,9 @@ impl VectorIndex for HNSWIndex { }); } - let results = self.hnsw.search(query.key.clone(), k, ef, bitmap)?; + let hnsw = self.hnsw.clone(); + let key = query.key.clone(); + let results = spawn_cpu(move || Ok(hnsw.search(key, k, ef, bitmap)?)).await?; let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| row_ids[x.id as usize])); let distances = Arc::new(Float32Array::from_iter_values( @@ -222,7 +225,7 @@ impl VectorIndex for HNSWIndex { .await?; Ok(Box::new(Self { - hnsw, + hnsw: Arc::new(hnsw), partition_storage: self.partition_storage.clone(), partition_metadata: self.partition_metadata.clone(), options: self.options.clone(), @@ -249,7 +252,7 @@ impl VectorIndex for HNSWIndex { .await?; Ok(Box::new(Self { - hnsw, + hnsw: Arc::new(hnsw), partition_storage: self.partition_storage.clone(), partition_metadata: self.partition_metadata.clone(), options: self.options.clone(), diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 8206f4ed65..c6ca66064d 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -2533,14 +2533,37 @@ mod tests { .map(|v| v.unwrap() as u32) .collect::>(); - let gt = ground_truth(&mat, query.values(), k); - let recall = results.intersection(>).count() as f32 / k as f32; + let row_ids = results[0] + .column_by_name(ROW_ID) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap() as u32) + .collect::>(); + let dists = results[0] + .column_by_name("_distance") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec(); + + let results = dists.into_iter().zip(row_ids.into_iter()).collect_vec(); + let gt = ground_truth(&mat, query.values(), k, DistanceType::L2); + + let results_set = results.iter().map(|r| r.1).collect::>(); + let gt_set = gt.iter().map(|r| r.1).collect::>(); + + let recall = results_set.intersection(>_set).count() as f32 / k as f32; assert!( recall >= 0.9, "recall: {}\n results: {:?}\n\ngt: {:?}", recall, - results.iter().sorted().collect_vec(), - gt.iter().sorted().collect_vec() + results, + gt, ); } From 3f99fd8b2f7a4e1d3eced970a560abf54f436454 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 24 Apr 2024 16:40:04 +0800 Subject: [PATCH 2/3] fix Signed-off-by: BubbleCal --- rust/lance/src/index/vector/ivf.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index c6ca66064d..cfb9bceecc 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -2523,16 +2523,6 @@ mod tests { assert_eq!(1, results.len()); assert_eq!(k, results[0].num_rows()); - let results = results[0] - .column_by_name(ROW_ID) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .map(|v| v.unwrap() as u32) - .collect::>(); - let row_ids = results[0] .column_by_name(ROW_ID) .unwrap() From 8de49b2757ae2b5fee5f5c8a8e20a16841b55e04 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 24 Apr 2024 20:47:22 +0800 Subject: [PATCH 3/3] fmt Signed-off-by: BubbleCal --- rust/lance/src/index/vector/hnsw.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lance/src/index/vector/hnsw.rs b/rust/lance/src/index/vector/hnsw.rs index 410839745e..d30d809851 100644 --- a/rust/lance/src/index/vector/hnsw.rs +++ b/rust/lance/src/index/vector/hnsw.rs @@ -170,7 +170,7 @@ impl VectorIndex for HNSWIndex { let hnsw = self.hnsw.clone(); let key = query.key.clone(); - let results = spawn_cpu(move || Ok(hnsw.search(key, k, ef, bitmap)?)).await?; + let results = spawn_cpu(move || hnsw.search(key, k, ef, bitmap)).await?; let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| row_ids[x.id as usize])); let distances = Arc::new(Float32Array::from_iter_values(