diff --git a/rust/src/index/vector.rs b/rust/src/index/vector.rs index b80319196d..dc063d9118 100644 --- a/rust/src/index/vector.rs +++ b/rust/src/index/vector.rs @@ -42,7 +42,7 @@ use crate::{ index::{ pb::vector_index_stage::Stage, vector::{ - diskann::DiskANNParams, + diskann::{DiskANNIndex, DiskANNParams}, ivf::Ivf, opq::{OPQIndex, OptimizedProductQuantizer}, pq::ProductQuantizer, @@ -316,10 +316,11 @@ pub(crate) async fn build_vector_index( } /// Open the Vector index on dataset, specified by the `uuid`. -pub(crate) async fn open_index<'a>( - dataset: &'a Dataset, +pub(crate) async fn open_index( + dataset: Arc, + column: &str, uuid: &str, -) -> Result> { +) -> Result> { if let Some(index) = dataset.session.index_cache.get(uuid) { return Ok(index); } @@ -362,11 +363,6 @@ pub(crate) async fn open_index<'a>( pb::index::Implementation::VectorIndex(vi) => vi, }; - let num_stages = vec_idx.stages.len(); - if num_stages != 2 && num_stages != 3 { - return Err(Error::IO("Only support IVF_(O)PQ now".to_string())); - }; - let metric_type = pb::VectorMetricType::from_i32(vec_idx.metric_type) .ok_or(Error::Index(format!( "Unsupported metric type value: {}", @@ -430,6 +426,18 @@ pub(crate) async fn open_index<'a>( let pq = Arc::new(ProductQuantizer::try_from(pq_proto).unwrap()); last_stage = Some(Arc::new(PQIndex::new(pq, metric_type))); } + Some(Stage::Diskann(diskann_proto)) => { + if last_stage.is_some() { + return Err(Error::Index(format!( + "DiskANN should be the only stage, but we got stages: {:?}", + vec_idx.stages + ))); + }; + let graph_path = index_dir.child(diskann_proto.filename.as_str()); + let diskann = + Arc::new(DiskANNIndex::try_new(dataset.clone(), column, &graph_path).await?); + last_stage = Some(diskann); + } _ => {} } } diff --git a/rust/src/index/vector/diskann.rs b/rust/src/index/vector/diskann.rs index 6d57bb64fa..184e05868d 100644 --- a/rust/src/index/vector/diskann.rs +++ b/rust/src/index/vector/diskann.rs @@ -16,7 +16,7 @@ /// /// Modified from diskann paper. The vector store is backed by the `lance` dataset. mod builder; -mod row_vertex; +pub(crate) mod row_vertex; mod search; use super::{ @@ -25,6 +25,8 @@ use super::{ }; use crate::index::vector::pq::PQBuildParams; pub(crate) use builder::build_diskann_index; +pub(crate) use row_vertex::RowVertex; +pub(crate) use search::DiskANNIndex; #[derive(Clone, Debug)] pub struct DiskANNParams { @@ -48,9 +50,9 @@ pub struct DiskANNParams { impl Default for DiskANNParams { fn default() -> Self { Self { - r: 90, + r: 50, alpha: 1.2, - l: 100, + l: 70, pq_params: PQBuildParams::default(), metric_type: MetricType::L2, } diff --git a/rust/src/index/vector/diskann/builder.rs b/rust/src/index/vector/diskann/builder.rs index bd64fa77f8..c59427acef 100644 --- a/rust/src/index/vector/diskann/builder.rs +++ b/rust/src/index/vector/diskann/builder.rs @@ -13,7 +13,9 @@ // limitations under the License. use std::collections::{BinaryHeap, HashSet}; +use std::sync::Arc; +use arrow_array::UInt32Array; use arrow_array::{cast::as_primitive_array, types::UInt64Type}; use arrow_select::concat::concat_batches; use futures::stream::{self, StreamExt, TryStreamExt}; @@ -74,7 +76,8 @@ pub(crate) async fn build_diskann_index( println!("DiskANN: second pass: {}s", now.elapsed().as_secs_f32()); let index_dir = dataset.indices_dir().child(uuid); - let graph_file = index_dir.child("diskann_graph.lance"); + let filename = "diskann_graph.lance"; + let graph_file = index_dir.child(filename); let mut write_params = WriteGraphParams::default(); write_params.batch_size = 2048 * 10; @@ -95,7 +98,7 @@ pub(crate) async fn build_diskann_index( name, uuid, graph.data.num_columns(), - graph_file.to_string().as_str(), + filename, &[medoid], params.metric_type, ¶ms, @@ -153,7 +156,8 @@ async fn init_graph( let distribution = Uniform::new(0, batch.num_rows()); // Randomly connect to r neighbors. for i in 0..graph.len() { - let mut neighbor_ids: HashSet = graph.neighbors(i)?.iter().copied().collect(); + let mut neighbor_ids: HashSet = + graph.neighbors(i).await?.values().iter().copied().collect(); while neighbor_ids.len() < r { let neighbor_id = rng.sample(distribution); @@ -162,17 +166,9 @@ async fn init_graph( } } - // Make bidirectional connections. { - let n = graph.neighbors_mut(i); - n.clear(); - n.extend(neighbor_ids.iter().copied()); - // Release mutable borrow on graph. - } - { - for neighbor_id in neighbor_ids.iter() { - graph.add_neighbor(*neighbor_id as usize, i); - } + let new_neighbors = Arc::new(UInt32Array::from_iter(neighbor_ids.iter().copied())); + graph.set_neighbors(i, new_neighbors); } } @@ -192,7 +188,7 @@ fn distance(matrix: &MatrixView, i: usize, j: usize) -> Result { } /// Algorithm 2 in the paper. -async fn robust_prune( +async fn robust_prune( graph: &GraphBuilder, id: usize, mut visited: HashSet, @@ -200,8 +196,8 @@ async fn robust_prune( r: usize, ) -> Result> { visited.remove(&id); - let neighbors = graph.neighbors(id)?; - visited.extend(neighbors.iter().map(|id| *id as usize)); + let neighbors = graph.neighbors(id).await?; + visited.extend(neighbors.values().iter().map(|id| *id as usize)); let mut heap: BinaryHeap = visited .iter() @@ -266,7 +262,7 @@ async fn find_medoid(vectors: &MatrixView, metric_type: MetricType) -> Result( +async fn index_once( graph: &mut GraphBuilder, medoid: usize, alpha: f32, @@ -283,20 +279,21 @@ async fn index_once( .row(i) .ok_or_else(|| Error::Index(format!("Cannot find vector with id {}", id)))?; - let state = greedy_search(graph, medoid, vector, 1, l)?; - - graph - .neighbors_mut(id) - .extend(state.visited.iter().map(|id| *id as u32)); + let state = greedy_search(graph, medoid, vector, 1, l).await?; let neighbors = robust_prune(graph, id, state.visited, alpha, r).await?; - graph.set_neighbors(id, neighbors.to_vec()); + graph.set_neighbors( + id, + Arc::new(neighbors.iter().copied().collect::()), + ); let fixed_graph: &GraphBuilder = graph; let neighbours = stream::iter(neighbors) .map(|j| async move { let mut neighbor_set: HashSet = fixed_graph - .neighbors(j as usize)? + .neighbors(j as usize) + .await? + .values() .iter() .map(|v| *v as usize) .collect(); @@ -316,7 +313,7 @@ async fn index_once( .try_collect::>() .await?; for (j, nbs) in neighbours { - graph.set_neighbors(j, nbs); + graph.set_neighbors(j, Arc::new(nbs.into_iter().collect::())); } } diff --git a/rust/src/index/vector/diskann/row_vertex.rs b/rust/src/index/vector/diskann/row_vertex.rs index c4306e9318..80b6ee058c 100644 --- a/rust/src/index/vector/diskann/row_vertex.rs +++ b/rust/src/index/vector/diskann/row_vertex.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::any::Any; + use arrow_array::Float32Array; use byteorder::{ByteOrder, LE}; @@ -23,7 +25,6 @@ use crate::Result; pub(crate) struct RowVertex { pub(crate) row_id: u64, - #[allow(dead_code)] pub(crate) vector: Option, } @@ -33,10 +34,28 @@ impl RowVertex { } } -impl Vertex for RowVertex {} +impl Vertex for RowVertex { + fn vector(&self) -> &[f32] { + self.vector.as_ref().unwrap().values() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } +} pub(crate) struct RowVertexSerDe {} +impl RowVertexSerDe { + pub(crate) fn new() -> Self { + Self {} + } +} + impl VertexSerDe for RowVertexSerDe { fn size(&self) -> usize { 8 diff --git a/rust/src/index/vector/diskann/search.rs b/rust/src/index/vector/diskann/search.rs index 4a971c4ab8..1af9ae150f 100644 --- a/rust/src/index/vector/diskann/search.rs +++ b/rust/src/index/vector/diskann/search.rs @@ -13,14 +13,41 @@ // limitations under the License. use std::{ + any::Any, cmp::Reverse, collections::{BTreeMap, BinaryHeap, HashSet}, + sync::Arc, }; +use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema}; +use async_trait::async_trait; +use object_store::path::Path; use ordered_float::OrderedFloat; -use crate::index::vector::graph::{Graph, VertexWithDistance}; -use crate::Result; +use super::row_vertex::{RowVertex, RowVertexSerDe}; +use crate::{ + dataset::{Dataset, ROW_ID}, + index::{ + vector::{ + graph::{GraphReadParams, PersistedGraph}, + SCORE_COL, + }, + Index, + }, + Result, +}; +use crate::{ + index::{ + vector::VectorIndex, + vector::{ + graph::{Graph, VertexWithDistance}, + Query, + }, + }, + io::object_reader::ObjectReader, + Error, +}; /// DiskANN search state. pub(crate) struct SearchState { @@ -37,6 +64,14 @@ pub(crate) struct SearchState { /// Heap maintains the unvisited vertices, ordered by the distance. heap: BinaryHeap>, + /// Track the ones that have been computed distance with and pushed to heap already. + /// + /// This is different to visited, mostly because visited has a different meaning in the + /// paper, which is the one that has been popped from the heap. + /// But we wanted to avoid repeatly computing `argmin` over the heap, so we need another + /// meaning for visited. + heap_visisted: HashSet, + /// Search size, `L` parameter in the paper. L must be greater or equal than k. l: usize, @@ -53,6 +88,7 @@ impl SearchState { visited: HashSet::new(), candidates: BTreeMap::new(), heap: BinaryHeap::new(), + heap_visisted: HashSet::new(), k, l, } @@ -61,9 +97,7 @@ impl SearchState { /// Return the next unvisited vertex. fn pop(&mut self) -> Option { while let Some(vertex) = self.heap.pop() { - // println!("Pop {} visited {:?}", vertex.0.id, self.visited); - - if self.is_visited(vertex.0.id) || !self.candidates.contains_key(&vertex.0.distance) { + if !self.candidates.contains_key(&vertex.0.distance) { // The vertex has been removed from the candidate lists, // from [`push()`]. continue; @@ -75,8 +109,10 @@ impl SearchState { None } - /// Push a new (unvisited) fvertex into the search state. + /// Push a new (unvisited) vertex into the search state. fn push(&mut self, vertex_id: usize, distance: f32) { + assert!(!self.visited.contains(&vertex_id)); + self.heap_visisted.insert(vertex_id); self.heap .push(Reverse(VertexWithDistance::new(vertex_id, distance))); self.candidates.insert(OrderedFloat(distance), vertex_id); @@ -92,7 +128,7 @@ impl SearchState { /// Returns true if the vertex has been visited. fn is_visited(&self, vertex_id: usize) -> bool { - self.visited.contains(&vertex_id) + self.visited.contains(&vertex_id) || self.heap_visisted.contains(&vertex_id) } } @@ -105,8 +141,8 @@ impl SearchState { /// - query: The query vector. /// - k: The number of nearest neighbors to return. /// - search_size: Search list size, L in the paper. -pub(crate) fn greedy_search( - graph: &dyn Graph, +pub(crate) async fn greedy_search( + graph: &(dyn Graph + Send + Sync), start: usize, query: &[f32], k: usize, @@ -116,17 +152,19 @@ pub(crate) fn greedy_search( // A map from distance to vertex id. let mut state = SearchState::new(k, search_size); - let dist = graph.distance_to(query, start)?; + let dist = graph.distance_to(query, start).await?; state.push(start, dist); while let Some(id) = state.pop() { state.visit(id); - for neighbor_id in graph.neighbors(id)?.iter() { + + let neighbors = graph.neighbors(id).await?; + for neighbor_id in neighbors.values() { let neighbor_id = *neighbor_id as usize; if state.is_visited(neighbor_id) { // Already visited. continue; } - let dist = graph.distance_to(query, neighbor_id)?; + let dist = graph.distance_to(query, neighbor_id).await?; state.push(neighbor_id, dist); } } @@ -134,6 +172,81 @@ pub(crate) fn greedy_search( Ok(state) } +pub struct DiskANNIndex { + graph: PersistedGraph, +} + +impl std::fmt::Debug for DiskANNIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DiskANNIndex") + } +} + +impl DiskANNIndex { + /// Creates a new DiskANN index. + + pub async fn try_new( + dataset: Arc, + index_column: &str, + graph_path: &Path, + ) -> Result { + let params = GraphReadParams::default(); + let serde = Arc::new(RowVertexSerDe::new()); + let graph = + PersistedGraph::try_new(dataset, index_column, graph_path, params, serde).await?; + Ok(Self { graph }) + } +} + +impl Index for DiskANNIndex { + fn as_any(&self) -> &dyn Any { + self + } +} + +#[async_trait] +impl VectorIndex for DiskANNIndex { + async fn search(&self, query: &Query) -> Result { + let state = greedy_search(&self.graph, 0, query.key.values(), query.k, query.k * 2).await?; + let schema = Arc::new(Schema::new(vec![ + Field::new(ROW_ID, DataType::UInt64, false), + Field::new(SCORE_COL, DataType::Float32, false), + ])); + + let row_ids: UInt64Array = state + .candidates + .iter() + .take(query.k) + .map(|(_, id)| *id as u64) + .collect(); + let scores: Float32Array = state + .candidates + .iter() + .take(query.k) + .map(|(d, _)| **d) + .collect(); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(row_ids) as ArrayRef, Arc::new(scores) as ArrayRef], + )?; + Ok(batch) + } + + fn is_loadable(&self) -> bool { + false + } + + async fn load( + &self, + _reader: &dyn ObjectReader, + _offset: usize, + _length: usize, + ) -> Result> { + Err(Error::Index("DiskANNIndex is not loadable".to_string())) + } +} + #[cfg(test)] mod test { diff --git a/rust/src/index/vector/graph.rs b/rust/src/index/vector/graph.rs index a4cc61cb9f..052f084359 100644 --- a/rust/src/index/vector/graph.rs +++ b/rust/src/index/vector/graph.rs @@ -15,6 +15,11 @@ //! Graph-based vector index. //! +use std::any::Any; +use std::sync::Arc; + +use arrow_array::UInt32Array; +use async_trait::async_trait; use ordered_float::OrderedFloat; pub(crate) mod builder; @@ -24,17 +29,26 @@ use crate::Result; pub use persisted::*; /// Graph +#[async_trait] pub trait Graph { /// Distance between two vertices, specified by their IDs. - fn distance(&self, a: usize, b: usize) -> Result; + async fn distance(&self, a: usize, b: usize) -> Result; - fn distance_to(&self, query: &[f32], idx: usize) -> Result; + /// Distance from query vector to a vertex identified by the idx. + async fn distance_to(&self, query: &[f32], idx: usize) -> Result; - fn neighbors(&self, id: usize) -> Result<&[u32]>; + /// Return the neighbor IDs. + async fn neighbors(&self, id: usize) -> Result>; } /// Vertex (metadata). It does not include the actual data. -pub trait Vertex {} +pub trait Vertex: Clone { + fn as_any(&self) -> &dyn Any; + + fn as_any_mut(&mut self) -> &mut dyn Any; + + fn vector(&self) -> &[f32]; +} /// Vertex SerDe. Used for serializing and deserializing the vertex. pub(crate) trait VertexSerDe { diff --git a/rust/src/index/vector/graph/builder.rs b/rust/src/index/vector/graph/builder.rs index 9f25cd7a38..dada46d797 100644 --- a/rust/src/index/vector/graph/builder.rs +++ b/rust/src/index/vector/graph/builder.rs @@ -16,6 +16,9 @@ use std::sync::Arc; +use arrow_array::UInt32Array; +use async_trait::async_trait; + use super::{Graph, Vertex}; use crate::arrow::linalg::MatrixView; use crate::index::vector::MetricType; @@ -29,13 +32,13 @@ pub(crate) struct Node { /// Neighbors are the ids of vertex in the graph. /// This id is not the same as the row_id in the original lance dataset. - pub(crate) neighbors: Vec, + pub(crate) neighbors: Arc, } /// A Graph that allows dynamically build graph to be persisted later. /// /// It requires all vertices to be of the same size. -pub(crate) struct GraphBuilder { +pub(crate) struct GraphBuilder { pub(crate) nodes: Vec>, /// Hold all vectors in memory for fast access at the moment. @@ -48,14 +51,14 @@ pub(crate) struct GraphBuilder { distance_func: Arc f32 + Send + Sync>, } -impl<'a, V: Vertex + Clone> GraphBuilder { +impl<'a, V: Vertex + Clone + Sync + Send> GraphBuilder { pub fn new(vertices: &[V], data: MatrixView, metric_type: MetricType) -> Self { Self { nodes: vertices .iter() .map(|v| Node { vertex: v.clone(), - neighbors: Vec::new(), + neighbors: Arc::new(UInt32Array::from(vec![] as Vec)), }) .collect(), data, @@ -80,23 +83,15 @@ impl<'a, V: Vertex + Clone> GraphBuilder { &mut self.nodes[id].vertex } - pub fn neighbors_mut(&mut self, id: usize) -> &mut Vec { - &mut self.nodes[id].neighbors - } - /// Set neighbors of a node. - pub fn set_neighbors(&mut self, id: usize, neighbors: impl Into>) { - self.nodes[id].neighbors = neighbors.into(); - } - - /// Add a neighbor to a specific vertex. - pub fn add_neighbor(&mut self, vertex: usize, neighbor: usize) { - self.nodes[vertex].neighbors.push(neighbor as u32); + pub fn set_neighbors(&mut self, id: usize, neighbors: Arc) { + self.nodes[id].neighbors = neighbors; } } -impl Graph for GraphBuilder { - fn distance(&self, a: usize, b: usize) -> Result { +#[async_trait] +impl Graph for GraphBuilder { + async fn distance(&self, a: usize, b: usize) -> Result { let vector_a = self.data.row(a).ok_or_else(|| { Error::Index(format!( "Vector index is out of range: {} >= {}", @@ -115,7 +110,7 @@ impl Graph for GraphBuilder { Ok((self.distance_func)(vector_a, vector_b)) } - fn distance_to(&self, query: &[f32], idx: usize) -> Result { + async fn distance_to(&self, query: &[f32], idx: usize) -> Result { let vector = self.data.row(idx).ok_or_else(|| { Error::Index(format!( "Attempt to access row {} in a matrix with {} rows", @@ -126,8 +121,8 @@ impl Graph for GraphBuilder { Ok((self.distance_func)(query, vector)) } - fn neighbors(&self, id: usize) -> Result<&[u32]> { - Ok(self.nodes[id].neighbors.as_slice()) + async fn neighbors(&self, id: usize) -> Result> { + Ok(self.nodes[id].neighbors.clone()) } } @@ -143,10 +138,22 @@ mod tests { val: f32, } - impl Vertex for FooVertex {} + impl Vertex for FooVertex { + fn vector(&self) -> &[f32] { + todo!() + } + + fn as_any(&self) -> &dyn std::any::Any { + todo!() + } + + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + todo!() + } + } - #[test] - fn test_construct_builder() { + #[tokio::test] + async fn test_construct_builder() { let nodes = (0..100) .map(|v| FooVertex { id: v as u32, @@ -158,7 +165,7 @@ mod tests { assert_eq!(builder.len(), 100); assert_eq!(builder.vertex(77).id, 77); assert_relative_eq!(builder.vertex(77).val, 38.5); - assert!(builder.neighbors(55).unwrap().is_empty()); + assert!(builder.neighbors(55).await.unwrap().is_empty()); builder.vertex_mut(88).val = 22.0; assert_relative_eq!(builder.vertex(88).val, 22.0); diff --git a/rust/src/index/vector/graph/persisted.rs b/rust/src/index/vector/graph/persisted.rs index 2d789c84e6..2fa5fed42d 100644 --- a/rust/src/index/vector/graph/persisted.rs +++ b/rust/src/index/vector/graph/persisted.rs @@ -12,22 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::fmt::Debug; use std::sync::{Arc, Mutex}; use arrow::array::{as_list_array, as_primitive_array}; +use arrow_array::cast::AsArray; +use arrow_array::Float32Array; use arrow_array::{ builder::{FixedSizeBinaryBuilder, ListBuilder, UInt32Builder}, Array, RecordBatch, UInt32Array, }; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; +use async_trait::async_trait; use lru_time_cache::LruCache; use object_store::path::Path; -use super::builder::GraphBuilder; +use super::{builder::GraphBuilder, Graph}; use super::{Vertex, VertexSerDe}; -use crate::arrow::as_fixed_size_binary_array; +use crate::arrow::as_fixed_size_list_array; +use crate::dataset::Dataset; use crate::datatypes::Schema; +use crate::index::vector::diskann::RowVertex; use crate::io::{FileReader, FileWriter, ObjectStore}; +use crate::{arrow::as_fixed_size_binary_array, linalg::l2::L2}; use crate::{Error, Result}; const NEIGHBORS_COL: &str = "neighbors"; @@ -53,7 +60,13 @@ impl Default for GraphReadParams { } /// Persisted graph on disk, stored in the file. -pub(crate) struct PersistedGraph { +pub(crate) struct PersistedGraph { + /// Reference to the dataset. + dataset: Arc, + + /// Vector column. + vector_column_projection: Schema, + reader: FileReader, /// Vertex size in bytes. @@ -75,17 +88,19 @@ pub(crate) struct PersistedGraph { params: GraphReadParams, /// SerDe for vertex. - serde: Box>, + serde: Arc + Send + Sync>, } -impl PersistedGraph { +impl PersistedGraph { /// Try open a persisted graph from a given URI. pub(crate) async fn try_new( - object_store: &ObjectStore, + dataset: Arc, + vector_column: &str, path: &Path, params: GraphReadParams, - serde: Box>, + serde: Arc + Send + Sync>, ) -> Result> { + let object_store = dataset.object_store(); let file_reader = FileReader::try_new(object_store, path).await?; let schema = file_reader.schema(); @@ -107,7 +122,11 @@ impl PersistedGraph { }; let neighbors_projection = schema.project(&[NEIGHBORS_COL])?; + let vector_column_projection = dataset.schema().project(&[vector_column])?; + Ok(Self { + dataset, + vector_column_projection, reader: file_reader, vertex_size, vertex_projection, @@ -136,20 +155,37 @@ impl PersistedGraph { return Ok(vertex.clone()); } } - let prefetch_size = self.params.prefetch_byte_size / self.vertex_size + 1; - let end = std::cmp::min(self.len(), id as usize + prefetch_size); + let end = (id + 1) as usize; let batch = self .reader - .read_range(id as usize..end, &self.vertex_projection) + .read_range(id as usize..(id + 1) as usize, &self.vertex_projection) .await?; assert_eq!(batch.num_rows(), end - id as usize); + + let array = as_fixed_size_binary_array(batch.column(0)); + let mut vertices = vec![]; + for vertex_bytes in array.iter() { + let mut vertex = self.serde.deserialize(vertex_bytes.unwrap())?; + let mut row_vector = vertex.as_any_mut().downcast_mut::().unwrap(); + let batch = self + .dataset + .take_rows(&[row_vector.row_id as u64], &self.vector_column_projection) + .await?; + + let column = as_fixed_size_list_array(batch.column(0)); + let values = column.value(0); + let vector: Float32Array = values.as_primitive().clone(); + row_vector.vector = Some(vector); + vertices.push(vertex); + } + { let mut cache = self.cache.lock().unwrap(); - let array = as_fixed_size_binary_array(batch.column(0)); - for (i, vertex_bytes) in array.iter().enumerate() { - let vertex = self.serde.deserialize(vertex_bytes.unwrap())?; + for i in 0..vertices.len() { + let vertex = vertices[i].clone(); cache.insert(id + i as u32, Arc::new(vertex)); } + Ok(cache.get(&id).unwrap().clone()) } } @@ -182,6 +218,46 @@ impl PersistedGraph { } } +#[async_trait] +impl Graph for PersistedGraph { + async fn distance(&self, a: usize, b: usize) -> Result { + let vertex_a = self.vertex(a as u32).await?; + self.distance_to(vertex_a.vector(), b).await + } + + async fn distance_to(&self, query: &[f32], idx: usize) -> Result { + let vertex = self.vertex(idx as u32).await?; + Ok(vertex.vector().l2(query)) + } + + /// Get the neighbors of a vertex, specified by its id. + async fn neighbors(&self, id: usize) -> Result> { + { + let mut cache = self.neighbors_cache.lock().unwrap(); + if let Some(neighbors) = cache.get(&(id as u32)) { + return Ok(neighbors.clone()); + } + } + let batch = self + .reader + .read_range(id as usize..(id + 1) as usize, &self.neighbors_projection) + .await?; + { + let mut cache = self.neighbors_cache.lock().unwrap(); + + let array = as_list_array(batch.column(0)); + if array.len() < 1 { + return Err(Error::Index("Invalid graph".to_string())); + } + let value = array.value(0); + let nb_array: &UInt32Array = as_primitive_array(value.as_ref()); + let neighbors = Arc::new(nb_array.clone()); + cache.insert(id as u32, neighbors.clone()); + Ok(neighbors.clone()) + } + } +} + /// Parameters for writing the graph index. pub struct WriteGraphParams { pub batch_size: usize, @@ -194,7 +270,7 @@ impl Default for WriteGraphParams { } /// Write the graph to a file. -pub(crate) async fn write_graph( +pub(crate) async fn write_graph( graph: &GraphBuilder, object_store: &ObjectStore, path: &Path, @@ -231,7 +307,7 @@ pub(crate) async fn write_graph( vertex_builder.append_value(serde.serialize(&node.vertex))?; neighbors_builder .values() - .append_slice(node.neighbors.as_slice()); + .append_slice(node.neighbors.values()); neighbors_builder.append(true); } let batch = RecordBatch::try_new( @@ -252,8 +328,16 @@ pub(crate) async fn write_graph( #[cfg(test)] mod tests { + use arrow_array::{FixedSizeListArray, RecordBatchReader}; + use super::*; - use crate::{arrow::linalg::MatrixView, index::vector::MetricType}; + use crate::{ + arrow::{linalg::MatrixView, FixedSizeListArrayExt, RecordBatchBuffer}, + dataset::WriteParams, + index::vector::diskann::row_vertex::RowVertexSerDe, + index::vector::MetricType, + utils::testing::generate_random_array, + }; #[derive(Clone, Debug)] struct FooVertex { @@ -262,7 +346,19 @@ mod tests { pq: Vec, } - impl Vertex for FooVertex {} + impl Vertex for FooVertex { + fn vector(&self) -> &[f32] { + unimplemented!() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } + } struct FooVertexSerDe {} @@ -287,36 +383,70 @@ mod tests { #[tokio::test] async fn test_persisted_graph() { - let store = ObjectStore::memory(); - let path = Path::from("/graph"); - - let nodes = (0..100) - .map(|v| FooVertex { - row_id: v as u32, - pq: vec![0; 16], + use tempfile::tempdir; + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let total = 100; + let dim = 32; + + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + dim as i32, + ), + true, + )])); + let data = generate_random_array(total * dim); + let batches = RecordBatchBuffer::new(vec![RecordBatch::try_new( + schema.clone(), + vec![Arc::new( + FixedSizeListArray::try_new(&data, dim as i32).unwrap(), + )], + ) + .unwrap()]); + + let mut write_params = WriteParams::default(); + write_params.max_rows_per_file = 40; + write_params.max_rows_per_group = 10; + let mut batches: Box = Box::new(batches); + let dataset = Dataset::write(&mut batches, test_uri, Some(write_params)) + .await + .unwrap(); + + let graph_path = dataset.indices_dir().child("graph"); + let nodes = (0..total) + .map(|v| RowVertex { + row_id: v as u64, + vector: Some(generate_random_array(dim).into()), }) .collect::>(); let mut builder = GraphBuilder::new(&nodes, MatrixView::random(100, 16), MetricType::L2); - for i in 0..100 { - for j in i..i + 10 { - builder.add_neighbor(i, j); - } + for i in 0..total as u32 { + let neighbors = Arc::new(UInt32Array::from_iter_values(i..i + 10)); + builder.set_neighbors(i as usize, neighbors); } - let serde = Box::new(FooVertexSerDe {}); + let serde = Arc::new(RowVertexSerDe {}); write_graph( &builder, - &store, - &path, + dataset.object_store(), + &graph_path, &WriteGraphParams::default(), serde.as_ref(), ) .await .unwrap(); - let graph = - PersistedGraph::::try_new(&store, &path, GraphReadParams::default(), serde) - .await - .unwrap(); + let graph = PersistedGraph::::try_new( + Arc::new(dataset), + "vector", + &graph_path, + GraphReadParams::default(), + serde, + ) + .await + .unwrap(); let vertex = graph.vertex(77).await.unwrap(); assert_eq!(vertex.row_id, 77); diff --git a/rust/src/index/vector/opq.rs b/rust/src/index/vector/opq.rs index 232f5a7533..1b57b8d9a7 100644 --- a/rust/src/index/vector/opq.rs +++ b/rust/src/index/vector/opq.rs @@ -390,7 +390,9 @@ mod tests { .unwrap(); let uuid = index_file.file_name().to_str().unwrap().to_string(); - let index = open_index(&dataset, &uuid).await.unwrap(); + let index = open_index(Arc::new(dataset), "vector", &uuid) + .await + .unwrap(); if with_opq { let opq_idx = index.as_any().downcast_ref::().unwrap(); diff --git a/rust/src/io/exec/knn.rs b/rust/src/io/exec/knn.rs index a6d3d7fd33..22b8607b7d 100644 --- a/rust/src/io/exec/knn.rs +++ b/rust/src/io/exec/knn.rs @@ -220,7 +220,7 @@ impl KNNIndexStream { let q = query.clone(); let name = index_name.to_string(); let bg_thread = tokio::spawn(async move { - let index = match open_index(dataset.as_ref(), &name).await { + let index = match open_index(dataset, &q.column, &name).await { Ok(idx) => idx, Err(e) => { tx.send(Err(datafusion::error::DataFusionError::Execution(format!( diff --git a/rust/src/linalg/l2.rs b/rust/src/linalg/l2.rs index 426f439d8e..db52b5aefc 100644 --- a/rust/src/linalg/l2.rs +++ b/rust/src/linalg/l2.rs @@ -80,6 +80,7 @@ impl L2 for Float32Array { } /// Compute L2 distance between two vectors. +#[inline] pub fn l2_distance(from: &[f32], to: &[f32]) -> f32 { from.l2(to) }