diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 9f787fa80a..9193261a1b 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -661,6 +661,7 @@ mod test { use super::*; use crate::arrow::*; use crate::dataset::WriteMode; + use crate::index::vector::diskann::DiskANNParams; use crate::index::{ DatasetIndexExt, {vector::VectorIndexParams, IndexType}, @@ -1586,4 +1587,108 @@ mod test { .unwrap(); concat_batches(&batches[0].schema(), &batches).unwrap(); } + + #[tokio::test] + async fn test_ann_with_deletion() { + let vec_params = vec![ + VectorIndexParams::with_diskann_params(MetricType::L2, DiskANNParams::new(10, 1.5, 10)), + VectorIndexParams::ivf_pq(4, 8, 2, false, MetricType::L2, 2), + ]; + for params in vec_params { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + // make dataset + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("i", DataType::Int32, true), + ArrowField::new( + "vec", + DataType::FixedSizeList( + Arc::new(ArrowField::new("item", DataType::Float32, true)), + 32, + ), + true, + ), + ])); + + // vectors are [0, 0, 0, ...] [1, 1, 1, ...] + let vector_values: Float32Array = (0..32 * 512).map(|v| (v / 32) as f32).collect(); + let vectors = FixedSizeListArray::try_new(&vector_values, 32).unwrap(); + + let batches = RecordBatchBuffer::new(vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..512)), + Arc::new(vectors), + ], + ) + .unwrap()]); + + let mut reader: Box = Box::new(batches); + let mut dataset = Dataset::write(&mut reader, test_uri, None).await.unwrap(); + + dataset + .create_index( + &["vec"], + IndexType::Vector, + Some("idx".to_string()), + ¶ms, + true, + ) + .await + .unwrap(); + + let mut scan = dataset.scan(); + // closest be i = 0..5 + let key: Float32Array = (0..32).map(|_v| 1.0 as f32).collect(); + scan.nearest("vec", &key, 5).unwrap(); + + let results = scan + .try_into_stream() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + + let expected_i = BTreeSet::from_iter(vec![0, 1, 2, 3, 4]); + let column_i = batch.column_by_name("i").unwrap(); + let actual_i: BTreeSet = as_primitive_array::(column_i.as_ref()) + .values() + .iter() + .copied() + .collect(); + assert_eq!(expected_i, actual_i); + + // DELETE top result and search again + + dataset.delete("i = 1").await.unwrap(); + let mut scan = dataset.scan(); + scan.nearest("vec", &key, 5).unwrap(); + + let results = scan + .try_into_stream() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + + // i=1 was deleted, and 5 is the next best, the reset shouldn't change + let expected_i = BTreeSet::from_iter(vec![0, 2, 3, 4, 5]); + let column_i = batch.column_by_name("i").unwrap(); + let actual_i: BTreeSet = as_primitive_array::(column_i.as_ref()) + .values() + .iter() + .copied() + .collect(); + assert_eq!(expected_i, actual_i); + } + } } diff --git a/rust/src/index/vector.rs b/rust/src/index/vector.rs index 12aedf1879..1a1af2bd96 100644 --- a/rust/src/index/vector.rs +++ b/rust/src/index/vector.rs @@ -49,6 +49,7 @@ use crate::{ }, }, io::{ + deletion::LruDeletionVectorStore, object_reader::{read_message, ObjectReader}, read_message_from_buf, read_metadata_offset, }, @@ -372,6 +373,13 @@ pub(crate) async fn open_index( .into(); let mut last_stage: Option> = None; + + let deletion_cache = Arc::new(LruDeletionVectorStore::new( + Arc::new(dataset.object_store().clone()), + object_store.base_path().clone(), + dataset.manifest.clone(), + 100_usize, + )); for stg in vec_idx.stages.iter().rev() { match stg.stage.as_ref() { Some(Stage::Transform(tf)) => { @@ -422,7 +430,11 @@ pub(crate) async fn open_index( }); }; let pq = Arc::new(ProductQuantizer::try_from(pq_proto).unwrap()); - last_stage = Some(Arc::new(PQIndex::new(pq, metric_type))); + last_stage = Some(Arc::new(PQIndex::new( + pq, + metric_type, + deletion_cache.clone(), + ))); } Some(Stage::Diskann(diskann_proto)) => { if last_stage.is_some() { @@ -434,8 +446,15 @@ pub(crate) async fn open_index( }); }; let graph_path = index_dir.child(diskann_proto.filename.as_str()); - let diskann = - Arc::new(DiskANNIndex::try_new(dataset.clone(), column, &graph_path).await?); + let diskann = Arc::new( + DiskANNIndex::try_new( + dataset.clone(), + column, + &graph_path, + deletion_cache.clone(), + ) + .await?, + ); last_stage = Some(diskann); } _ => {} diff --git a/rust/src/index/vector/diskann/search.rs b/rust/src/index/vector/diskann/search.rs index a240af6400..f86eaa2daa 100644 --- a/rust/src/index/vector/diskann/search.rs +++ b/rust/src/index/vector/diskann/search.rs @@ -35,6 +35,7 @@ use crate::{ }, Index, }, + io::deletion::LruDeletionVectorStore, Result, }; use crate::{ @@ -175,6 +176,8 @@ pub(crate) async fn greedy_search( pub struct DiskANNIndex { graph: PersistedGraph, + + deletion_cache: Arc, } impl std::fmt::Debug for DiskANNIndex { @@ -190,12 +193,16 @@ impl DiskANNIndex { dataset: Arc, index_column: &str, graph_path: &Path, + deletion_cache: Arc, ) -> Result { let params = GraphReadParams::default(); let serde = Arc::new(RowVertexSerDe::new()); let graph = PersistedGraph::try_new(dataset, index_column, graph_path, params, serde).await?; - Ok(Self { graph }) + Ok(Self { + graph, + deletion_cache, + }) } } @@ -214,18 +221,22 @@ impl VectorIndex for DiskANNIndex { Field::new(SCORE_COL, DataType::Float32, false), ])); - let row_ids: UInt64Array = state - .candidates + let mut candidates = Vec::with_capacity(query.k); + for (score, row) in state.candidates { + if candidates.len() == query.k { + break; + } + if !self.deletion_cache.as_ref().is_deleted(row as u64).await? { + candidates.push((score, row)); + } + } + + let row_ids: UInt64Array = candidates .iter() .take(query.k) .map(|(_, id)| *id as u64) .collect(); - let scores: Float32Array = state - .candidates - .iter() - .take(query.k) - .map(|(d, _)| **d) - .collect(); + let scores: Float32Array = candidates.iter().take(query.k).map(|(d, _)| **d).collect(); let batch = RecordBatch::try_new( schema, diff --git a/rust/src/index/vector/pq.rs b/rust/src/index/vector/pq.rs index 995090d9ea..2632331e95 100644 --- a/rust/src/index/vector/pq.rs +++ b/rust/src/index/vector/pq.rs @@ -17,6 +17,8 @@ use std::sync::Arc; use arrow::datatypes::Float32Type; use arrow_arith::aggregate::min; +use arrow_array::builder::UInt64Builder; +use arrow_array::types::UInt64Type; use arrow_array::{ builder::Float32Builder, cast::as_primitive_array, Array, ArrayRef, FixedSizeListArray, Float32Array, RecordBatch, UInt64Array, UInt8Array, @@ -34,6 +36,7 @@ use crate::arrow::*; use crate::dataset::ROW_ID; use crate::index::Index; use crate::index::{pb, vector::kmeans::train_kmeans, vector::SCORE_COL}; +use crate::io::deletion::LruDeletionVectorStore; use crate::io::object_reader::{read_fixed_stride_array, ObjectReader}; use crate::linalg::{l2::l2_distance_batch, norm_l2::norm_l2}; use crate::{Error, Result}; @@ -63,6 +66,9 @@ pub struct PQIndex { /// Metric type. metric_type: MetricType, + + /// Deletion vector cache. + deletion_lookup_cache: Arc, } impl std::fmt::Debug for PQIndex { @@ -77,7 +83,11 @@ impl std::fmt::Debug for PQIndex { impl PQIndex { /// Load a PQ index (page) from the disk. - pub(crate) fn new(pq: Arc, metric_type: MetricType) -> Self { + pub(crate) fn new( + pq: Arc, + metric_type: MetricType, + deletion_cache: Arc, + ) -> Self { Self { nbits: pq.num_bits, num_sub_vectors: pq.num_sub_vectors, @@ -86,6 +96,7 @@ impl PQIndex { row_ids: None, pq, metric_type, + deletion_lookup_cache: deletion_cache, } } @@ -257,14 +268,26 @@ impl VectorIndex for PQIndex { let row_ids = read_fixed_stride_array(reader, &DataType::UInt64, row_id_offset, length, ..).await?; + let mut filtered_row_id_builder = UInt64Builder::new(); + let deletion_checker = self.deletion_lookup_cache.as_ref(); + // TODO: consider a more optimized way of reading + // group by frag_id and check per frag in one go + for row_id in as_primitive_array::(row_ids.as_ref()) { + let row = row_id.expect("Found null row id."); + if !deletion_checker.is_deleted(row).await? { + filtered_row_id_builder.append_value(row); + } + } + Ok(Arc::new(Self { nbits: self.pq.num_bits, num_sub_vectors: self.pq.num_sub_vectors, dimension: self.pq.dimension, code: Some(Arc::new(as_primitive_array(&pq_code).clone())), - row_ids: Some(Arc::new(as_primitive_array(&row_ids).clone())), + row_ids: Some(Arc::new(filtered_row_id_builder.finish())), pq: self.pq.clone(), metric_type: self.metric_type, + deletion_lookup_cache: self.deletion_lookup_cache.clone(), })) } } diff --git a/rust/src/io/deletion.rs b/rust/src/io/deletion.rs index c9189f255a..40d853ddba 100644 --- a/rust/src/io/deletion.rs +++ b/rust/src/io/deletion.rs @@ -1,5 +1,6 @@ use std::ops::Range; use std::slice::Iter; +use std::sync::Mutex; use std::{collections::HashSet, sync::Arc}; use arrow::ipc::reader::FileReader as ArrowFileReader; @@ -8,6 +9,7 @@ use arrow::ipc::CompressionType; use arrow_array::{BooleanArray, RecordBatch, UInt32Array}; use arrow_schema::{ArrowError, DataType, Field, Schema}; use bytes::Buf; +use lru_time_cache::LruCache; use object_store::path::Path; use rand::Rng; use roaring::bitmap::RoaringBitmap; @@ -16,7 +18,7 @@ use snafu::ResultExt; use super::ObjectStore; use crate::dataset::DELETION_DIRS; use crate::error::{box_error, CorruptFileSnafu}; -use crate::format::{DeletionFile, DeletionFileType, Fragment}; +use crate::format::{DeletionFile, DeletionFileType, Fragment, Manifest}; use crate::{Error, Result}; /// Threshold for when a DeletionVector::Set should be promoted to a DeletionVector::Bitmap. @@ -341,6 +343,74 @@ pub(crate) async fn read_deletion_file( } } +pub struct LruDeletionVectorStore { + // can't clone mutex, so need to arc it + cache: Arc>>>, + object_store: Arc, + path: Path, + manifest: Arc, +} + +impl LruDeletionVectorStore { + pub(crate) fn new( + object_store: Arc, + path: Path, + manifest: Arc, + cache_capacity: usize, + ) -> Self { + Self { + cache: Arc::new(Mutex::new(LruCache::with_capacity(cache_capacity))), + object_store, + path, + manifest, + } + } + + pub async fn is_deleted(&self, row_id: u64) -> Result { + let frag_id = row_id >> 32; + let local_row_id = row_id as u32; + + let deletion_vec = { + let val_in_cache: Option> = { + let mut cache = self.cache.lock().unwrap(); + cache.get(&frag_id).map(|v| v.clone()) + }; + + // Lock is released while we do IO so we block others or poision the lock + if val_in_cache.is_none() { + let fragment = self + .manifest + .as_ref() + .fragments + .as_ref() + .iter() + .find(|frag| frag.id == frag_id); + let dvec = match fragment { + Some(frag) => { + let dvec = read_deletion_file(&self.path, frag, self.object_store.as_ref()) + .await?; + dvec.unwrap_or(DeletionVector::NoDeletions) + } + None => DeletionVector::NoDeletions, + }; + + // IO is done, now lock again + let mut cache = self.cache.lock().unwrap(); + cache.insert(frag_id, Arc::new(dvec)); + cache.get(&frag_id).unwrap().clone() + } else { + val_in_cache.unwrap() + } + }; + + Ok(match deletion_vec.as_ref() { + DeletionVector::Bitmap(bitmap) => bitmap.contains(local_row_id), + DeletionVector::Set(set) => set.contains(&local_row_id), + DeletionVector::NoDeletions => false, + }) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/rust/src/session.rs b/rust/src/session.rs index 7f85772055..0760dc53da 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -53,19 +53,37 @@ impl Default for Session { mod tests { use super::*; - use crate::index::vector::{ - pq::{PQIndex, ProductQuantizer}, - MetricType, + use crate::{ + datatypes::Schema, + format::Manifest, + index::vector::{ + pq::{PQIndex, ProductQuantizer}, + MetricType, + }, + io::{deletion::LruDeletionVectorStore, ObjectStore}, }; - use std::sync::Arc; + use std::{collections::HashMap, sync::Arc}; - #[test] - fn test_disable_index_cache() { + #[tokio::test] + async fn test_disable_index_cache() { let no_cache = Session::new(0); assert!(no_cache.index_cache.get("abc").is_none()); + let schema = Schema { + fields: vec![], + metadata: HashMap::new(), + }; + let manifest = Arc::new(Manifest::new(&schema, Arc::new(vec![]))); + let object_store = Arc::new(ObjectStore::from_uri("memory://").await.unwrap().0); + let deletion_cache = Arc::new(LruDeletionVectorStore::new( + object_store.clone(), + object_store.as_ref().base_path().clone(), + manifest, + 100, + )); + let pq = Arc::new(ProductQuantizer::new(1, 8, 1)); - let idx = Arc::new(PQIndex::new(pq, MetricType::L2)); + let idx = Arc::new(PQIndex::new(pq, MetricType::L2, deletion_cache)); no_cache.index_cache.insert("abc", idx); assert!(no_cache.index_cache.get("abc").is_none());