diff --git a/rust/lance/src/index/vector/hnsw.rs b/rust/lance/src/index/vector/hnsw.rs index 61875f0d73..2c31529179 100644 --- a/rust/lance/src/index/vector/hnsw.rs +++ b/rust/lance/src/index/vector/hnsw.rs @@ -10,6 +10,7 @@ use std::{ use arrow_array::{Float32Array, RecordBatch, UInt64Array}; use async_trait::async_trait; +use lance_core::utils::tokio::spawn_cpu; use lance_core::{datatypes::Schema, Error, Result}; use lance_file::reader::FileReader; use lance_index::vector::quantizer::Quantizer; @@ -46,7 +47,7 @@ pub(crate) struct HNSWIndexOptions { #[derive(Clone)] pub(crate) struct HNSWIndex { - hnsw: HNSW, + hnsw: Arc, // TODO: move these into IVFIndex after the refactor is complete partition_storage: IvfQuantizationStorage, @@ -80,7 +81,7 @@ impl HNSWIndex { let ivf_store = IvfQuantizationStorage::open(aux_reader).await?; Ok(Self { - hnsw, + hnsw: Arc::new(hnsw), partition_storage: ivf_store, partition_metadata, options, @@ -176,7 +177,9 @@ impl VectorIndex for HNSWIndex { }); } - let results = self.hnsw.search(query.key.clone(), k, ef, bitmap)?; + let hnsw = self.hnsw.clone(); + let key = query.key.clone(); + let results = spawn_cpu(move || hnsw.search(key, k, ef, bitmap)).await?; let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| row_ids[x.id as usize])); let distances = Arc::new(Float32Array::from_iter_values( @@ -231,7 +234,7 @@ impl VectorIndex for HNSWIndex { .await?; Ok(Box::new(Self { - hnsw, + hnsw: Arc::new(hnsw), partition_storage: self.partition_storage.clone(), partition_metadata: self.partition_metadata.clone(), options: self.options.clone(), @@ -258,7 +261,7 @@ impl VectorIndex for HNSWIndex { .await?; Ok(Box::new(Self { - hnsw, + hnsw: Arc::new(hnsw), partition_storage: self.partition_storage.clone(), partition_metadata: self.partition_metadata.clone(), options: self.options.clone(), diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index db02bf6856..5eca9ab488 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -2750,7 +2750,7 @@ mod tests { assert_eq!(1, results.len()); assert_eq!(k, results[0].num_rows()); - let results = results[0] + let row_ids = results[0] .column_by_name(ROW_ID) .unwrap() .as_any() @@ -2758,17 +2758,29 @@ mod tests { .unwrap() .iter() .map(|v| v.unwrap() as u32) - .collect::>(); + .collect::>(); + let dists = results[0] + .column_by_name("_distance") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec(); - let gt = ground_truth(&mat, query.values(), k, distance_type); + let results = dists.into_iter().zip(row_ids.into_iter()).collect_vec(); + let gt = ground_truth(&mat, query.values(), k, DistanceType::L2); + + let results_set = results.iter().map(|r| r.1).collect::>(); let gt_set = gt.iter().map(|r| r.1).collect::>(); - let recall = results.intersection(>_set).count() as f32 / k as f32; + + let recall = results_set.intersection(>_set).count() as f32 / k as f32; assert!( recall >= 0.9, "recall: {}\n results: {:?}\n\ngt: {:?}", recall, - results.iter().sorted().collect_vec(), - gt_set.iter().sorted().collect_vec() + results, + gt, ); }