Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement deletion vector handling in index scan #958

Merged
merged 1 commit into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions rust/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ mod test {
use super::*;
use crate::arrow::*;
use crate::dataset::WriteMode;
use crate::index::vector::diskann::DiskANNParams;
use crate::index::{
DatasetIndexExt,
{vector::VectorIndexParams, IndexType},
Expand Down Expand Up @@ -1586,4 +1587,108 @@ mod test {
.unwrap();
concat_batches(&batches[0].schema(), &batches).unwrap();
}

#[tokio::test]
async fn test_ann_with_deletion() {
let vec_params = vec![
VectorIndexParams::with_diskann_params(MetricType::L2, DiskANNParams::new(10, 1.5, 10)),
VectorIndexParams::ivf_pq(4, 8, 2, false, MetricType::L2, 2),
];
for params in vec_params {
let test_dir = tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();

// make dataset
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
32,
),
true,
),
]));

// vectors are [0, 0, 0, ...] [1, 1, 1, ...]
let vector_values: Float32Array = (0..32 * 512).map(|v| (v / 32) as f32).collect();
let vectors = FixedSizeListArray::try_new(&vector_values, 32).unwrap();

let batches = RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..512)),
Arc::new(vectors),
],
)
.unwrap()]);

let mut reader: Box<dyn RecordBatchReader> = Box::new(batches);
let mut dataset = Dataset::write(&mut reader, test_uri, None).await.unwrap();

dataset
.create_index(
&["vec"],
IndexType::Vector,
Some("idx".to_string()),
&params,
true,
)
.await
.unwrap();

let mut scan = dataset.scan();
// closest be i = 0..5
let key: Float32Array = (0..32).map(|_v| 1.0 as f32).collect();
scan.nearest("vec", &key, 5).unwrap();

let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();

assert_eq!(results.len(), 1);
let batch = &results[0];

let expected_i = BTreeSet::from_iter(vec![0, 1, 2, 3, 4]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_i, actual_i);

// DELETE top result and search again

dataset.delete("i = 1").await.unwrap();
let mut scan = dataset.scan();
scan.nearest("vec", &key, 5).unwrap();

let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();

assert_eq!(results.len(), 1);
let batch = &results[0];

// i=1 was deleted, and 5 is the next best, the reset shouldn't change
let expected_i = BTreeSet::from_iter(vec![0, 2, 3, 4, 5]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_i, actual_i);
}
}
}
25 changes: 22 additions & 3 deletions rust/src/index/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ use crate::{
},
},
io::{
deletion::LruDeletionVectorStore,
object_reader::{read_message, ObjectReader},
read_message_from_buf, read_metadata_offset,
},
Expand Down Expand Up @@ -372,6 +373,13 @@ pub(crate) async fn open_index(
.into();

let mut last_stage: Option<Arc<dyn VectorIndex>> = None;

let deletion_cache = Arc::new(LruDeletionVectorStore::new(
Arc::new(dataset.object_store().clone()),
object_store.base_path().clone(),
dataset.manifest.clone(),
100_usize,
));
for stg in vec_idx.stages.iter().rev() {
match stg.stage.as_ref() {
Some(Stage::Transform(tf)) => {
Expand Down Expand Up @@ -422,7 +430,11 @@ pub(crate) async fn open_index(
});
};
let pq = Arc::new(ProductQuantizer::try_from(pq_proto).unwrap());
last_stage = Some(Arc::new(PQIndex::new(pq, metric_type)));
last_stage = Some(Arc::new(PQIndex::new(
pq,
metric_type,
deletion_cache.clone(),
)));
}
Some(Stage::Diskann(diskann_proto)) => {
if last_stage.is_some() {
Expand All @@ -434,8 +446,15 @@ pub(crate) async fn open_index(
});
};
let graph_path = index_dir.child(diskann_proto.filename.as_str());
let diskann =
Arc::new(DiskANNIndex::try_new(dataset.clone(), column, &graph_path).await?);
let diskann = Arc::new(
DiskANNIndex::try_new(
dataset.clone(),
column,
&graph_path,
deletion_cache.clone(),
)
.await?,
);
last_stage = Some(diskann);
}
_ => {}
Expand Down
29 changes: 20 additions & 9 deletions rust/src/index/vector/diskann/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use crate::{
},
Index,
},
io::deletion::LruDeletionVectorStore,
Result,
};
use crate::{
Expand Down Expand Up @@ -175,6 +176,8 @@ pub(crate) async fn greedy_search(

pub struct DiskANNIndex {
    /// Persisted DiskANN proximity graph over `RowVertex` nodes; traversed
    /// by `greedy_search` to answer ANN queries.
    graph: PersistedGraph<RowVertex>,

    /// Cache of deletion vectors, consulted at query time (`is_deleted`)
    /// to filter out rows deleted after the index was built.
    deletion_cache: Arc<LruDeletionVectorStore>,
}

impl std::fmt::Debug for DiskANNIndex {
Expand All @@ -190,12 +193,16 @@ impl DiskANNIndex {
dataset: Arc<Dataset>,
index_column: &str,
graph_path: &Path,
deletion_cache: Arc<LruDeletionVectorStore>,
) -> Result<Self> {
let params = GraphReadParams::default();
let serde = Arc::new(RowVertexSerDe::new());
let graph =
PersistedGraph::try_new(dataset, index_column, graph_path, params, serde).await?;
Ok(Self { graph })
Ok(Self {
graph,
deletion_cache,
})
}
}

Expand All @@ -214,18 +221,22 @@ impl VectorIndex for DiskANNIndex {
Field::new(SCORE_COL, DataType::Float32, false),
]));

let row_ids: UInt64Array = state
.candidates
let mut candidates = Vec::with_capacity(query.k);
for (score, row) in state.candidates {
if candidates.len() == query.k {
break;
}
if !self.deletion_cache.as_ref().is_deleted(row as u64).await? {
candidates.push((score, row));
}
}

let row_ids: UInt64Array = candidates
.iter()
.take(query.k)
.map(|(_, id)| *id as u64)
.collect();
let scores: Float32Array = state
.candidates
.iter()
.take(query.k)
.map(|(d, _)| **d)
.collect();
let scores: Float32Array = candidates.iter().take(query.k).map(|(d, _)| **d).collect();

let batch = RecordBatch::try_new(
schema,
Expand Down
27 changes: 25 additions & 2 deletions rust/src/index/vector/pq.rs
chebbyChefNEQ marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use std::sync::Arc;

use arrow::datatypes::Float32Type;
use arrow_arith::aggregate::min;
use arrow_array::builder::UInt64Builder;
use arrow_array::types::UInt64Type;
use arrow_array::{
builder::Float32Builder, cast::as_primitive_array, Array, ArrayRef, FixedSizeListArray,
Float32Array, RecordBatch, UInt64Array, UInt8Array,
Expand All @@ -34,6 +36,7 @@ use crate::arrow::*;
use crate::dataset::ROW_ID;
use crate::index::Index;
use crate::index::{pb, vector::kmeans::train_kmeans, vector::SCORE_COL};
use crate::io::deletion::LruDeletionVectorStore;
use crate::io::object_reader::{read_fixed_stride_array, ObjectReader};
use crate::linalg::{l2::l2_distance_batch, norm_l2::norm_l2};
use crate::{Error, Result};
Expand Down Expand Up @@ -63,6 +66,9 @@ pub struct PQIndex {

/// Metric type.
metric_type: MetricType,

/// Deletion vector cache.
deletion_lookup_cache: Arc<LruDeletionVectorStore>,
}

impl std::fmt::Debug for PQIndex {
Expand All @@ -77,7 +83,11 @@ impl std::fmt::Debug for PQIndex {

impl PQIndex {
/// Load a PQ index (page) from the disk.
pub(crate) fn new(pq: Arc<ProductQuantizer>, metric_type: MetricType) -> Self {
pub(crate) fn new(
    pq: Arc<ProductQuantizer>,
    metric_type: MetricType,
    deletion_cache: Arc<LruDeletionVectorStore>,
) -> Self {
    // `code` and `row_ids` start as `None`: the PQ codes and row ids are
    // read from disk later (see `load`), not at construction time.
    Self {
        nbits: pq.num_bits,
        num_sub_vectors: pq.num_sub_vectors,
        dimension: pq.dimension,
        code: None,
        row_ids: None,
        pq,
        metric_type,
        deletion_lookup_cache: deletion_cache,
    }

Expand Down Expand Up @@ -257,14 +268,26 @@ impl VectorIndex for PQIndex {
let row_ids =
read_fixed_stride_array(reader, &DataType::UInt64, row_id_offset, length, ..).await?;

let mut filtered_row_id_builder = UInt64Builder::new();
let deletion_checker = self.deletion_lookup_cache.as_ref();
// TODO: consider a more optimized way of reading
// group by frag_id and check per frag in one go
for row_id in as_primitive_array::<UInt64Type>(row_ids.as_ref()) {
let row = row_id.expect("Found null row id.");
if !deletion_checker.is_deleted(row).await? {
filtered_row_id_builder.append_value(row);
}
}

Ok(Arc::new(Self {
nbits: self.pq.num_bits,
num_sub_vectors: self.pq.num_sub_vectors,
dimension: self.pq.dimension,
code: Some(Arc::new(as_primitive_array(&pq_code).clone())),
row_ids: Some(Arc::new(as_primitive_array(&row_ids).clone())),
row_ids: Some(Arc::new(filtered_row_id_builder.finish())),
pq: self.pq.clone(),
metric_type: self.metric_type,
deletion_lookup_cache: self.deletion_lookup_cache.clone(),
}))
}
}
Expand Down
Loading