[Rust] DiskANN search #798

Merged
merged 33 commits into from
May 22, 2023
Commits
3dc2197
DiskANNIndex as VectorIndex interface
eddyxu Apr 23, 2023
dc28e7c
compilable
eddyxu Apr 23, 2023
84b5f9c
compilable
eddyxu Apr 23, 2023
97d7a9b
persist
eddyxu Apr 25, 2023
8bb3728
make it compile
eddyxu Apr 26, 2023
75ae727
buildable
eddyxu Apr 26, 2023
259ca2a
build return
eddyxu Apr 26, 2023
37a25bd
neighbour w.o lock
eddyxu Apr 27, 2023
2598d78
cargo fmt
eddyxu May 13, 2023
eaa91ec
fix some compiling
eddyxu May 14, 2023
2f46808
fix some builds
eddyxu May 15, 2023
c5efd9f
pass all tests
eddyxu May 16, 2023
3a0b02c
cargo fmt
eddyxu May 16, 2023
13ee0d6
add some profiling
eddyxu May 17, 2023
fa8a4f0
more profiling
eddyxu May 17, 2023
d01b636
profiling
eddyxu May 17, 2023
b72fa75
fix visit semantic in graph
eddyxu May 18, 2023
f182ed2
minor
eddyxu May 18, 2023
719b500
print out progresss
eddyxu May 18, 2023
69eb129
open diskann index
eddyxu May 19, 2023
68d398d
cargo fmt
eddyxu May 19, 2023
53605c8
load on demand
eddyxu May 19, 2023
5405c2c
add debug
eddyxu May 19, 2023
6dce9da
pass dataset as Arc
eddyxu May 19, 2023
3f07468
pass dataset and load vectors on demand
eddyxu May 19, 2023
bbf82d6
buildable
eddyxu May 19, 2023
292b40b
change
eddyxu May 19, 2023
2a0b11e
fix tests
eddyxu May 19, 2023
3690f2c
clean up prints
eddyxu May 19, 2023
15d4ca6
revert metrics
eddyxu May 19, 2023
2a62d89
Update rust/src/index/vector/diskann/builder.rs
eddyxu May 19, 2023
73c9e5d
Merge branch 'main' into lei/search_diskann
eddyxu May 21, 2023
1797eaf
Merge branch 'main' into lei/search_diskann
eddyxu May 22, 2023
26 changes: 17 additions & 9 deletions rust/src/index/vector.rs
@@ -42,7 +42,7 @@ use crate::{
index::{
pb::vector_index_stage::Stage,
vector::{
diskann::DiskANNParams,
diskann::{DiskANNIndex, DiskANNParams},
ivf::Ivf,
opq::{OPQIndex, OptimizedProductQuantizer},
pq::ProductQuantizer,
@@ -316,10 +316,11 @@ pub(crate) async fn build_vector_index(
}

/// Open the Vector index on dataset, specified by the `uuid`.
pub(crate) async fn open_index<'a>(
dataset: &'a Dataset,
pub(crate) async fn open_index(
Comment on lines 318 to +319
Contributor

Should the docs be updated for the new index?

Contributor Author

@eddyxu eddyxu May 19, 2023

This index is not ready for public use yet. After this issue, we still need to do some profiling and optimization work.

This PR just makes it ready for the team to start end-to-end benchmarking.

dataset: Arc<Dataset>,
column: &str,
uuid: &str,
) -> Result<Arc<dyn VectorIndex + 'a>> {
) -> Result<Arc<dyn VectorIndex>> {
if let Some(index) = dataset.session.index_cache.get(uuid) {
return Ok(index);
}
@@ -362,11 +363,6 @@ pub(crate) async fn open_index<'a>(
pb::index::Implementation::VectorIndex(vi) => vi,
};

let num_stages = vec_idx.stages.len();
if num_stages != 2 && num_stages != 3 {
return Err(Error::IO("Only support IVF_(O)PQ now".to_string()));
};

let metric_type = pb::VectorMetricType::from_i32(vec_idx.metric_type)
.ok_or(Error::Index(format!(
"Unsupported metric type value: {}",
@@ -430,6 +426,18 @@ pub(crate) async fn open_index<'a>(
let pq = Arc::new(ProductQuantizer::try_from(pq_proto).unwrap());
last_stage = Some(Arc::new(PQIndex::new(pq, metric_type)));
}
Some(Stage::Diskann(diskann_proto)) => {
if last_stage.is_some() {
return Err(Error::Index(format!(
"DiskANN should be the only stage, but we got stages: {:?}",
vec_idx.stages
)));
};
let graph_path = index_dir.child(diskann_proto.filename.as_str());
let diskann =
Arc::new(DiskANNIndex::try_new(dataset.clone(), column, &graph_path).await?);
last_stage = Some(diskann);
}
_ => {}
}
}
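
The new `Stage::Diskann` arm rejects any pipeline in which another stage has already been processed before DiskANN. A minimal sketch of that check in isolation, assuming a simplified `Stage` enum and a hypothetical `diskann_graph_file` helper (neither exists in the crate; the real code walks `vec_idx.stages` from the protobuf and builds a `DiskANNIndex`):

```rust
/// Simplified stand-in for the protobuf stage enum (hypothetical, not the crate's type).
#[allow(dead_code)]
#[derive(Debug)]
enum Stage {
    Ivf,
    Pq,
    Diskann { filename: String },
}

/// Hypothetical helper mirroring the check in the hunk above: if a DiskANN
/// stage is reached after any other stage has been processed, return an error.
/// On success, returns the graph file name of the DiskANN stage, if there is one.
fn diskann_graph_file(stages: &[Stage]) -> Result<Option<&str>, String> {
    let mut seen_stage = false; // plays the role of `last_stage.is_some()`
    let mut graph_file = None;
    for stage in stages {
        if let Stage::Diskann { filename } = stage {
            if seen_stage {
                return Err(format!(
                    "DiskANN should be the only stage, but we got stages: {:?}",
                    stages
                ));
            }
            graph_file = Some(filename.as_str());
        }
        seen_stage = true;
    }
    Ok(graph_file)
}
```

Like the hunk, this sketch only looks at what came before the DiskANN stage; it does not reject stages that follow it.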
8 changes: 5 additions & 3 deletions rust/src/index/vector/diskann.rs
@@ -16,7 +16,7 @@
///
/// Modified from diskann paper. The vector store is backed by the `lance` dataset.
mod builder;
mod row_vertex;
pub(crate) mod row_vertex;
mod search;

use super::{
@@ -25,6 +25,8 @@ use super::{
};
use crate::index::vector::pq::PQBuildParams;
pub(crate) use builder::build_diskann_index;
pub(crate) use row_vertex::RowVertex;
pub(crate) use search::DiskANNIndex;

#[derive(Clone, Debug)]
pub struct DiskANNParams {
@@ -48,9 +50,9 @@ pub struct DiskANNParams {
impl Default for DiskANNParams {
fn default() -> Self {
Self {
r: 90,
r: 50,
alpha: 1.2,
l: 100,
l: 70,
pq_params: PQBuildParams::default(),
metric_type: MetricType::L2,
}
47 changes: 22 additions & 25 deletions rust/src/index/vector/diskann/builder.rs
@@ -13,7 +13,9 @@
// limitations under the License.

use std::collections::{BinaryHeap, HashSet};
use std::sync::Arc;

use arrow_array::UInt32Array;
use arrow_array::{cast::as_primitive_array, types::UInt64Type};
use arrow_select::concat::concat_batches;
use futures::stream::{self, StreamExt, TryStreamExt};
@@ -74,7 +76,8 @@ pub(crate) async fn build_diskann_index(
println!("DiskANN: second pass: {}s", now.elapsed().as_secs_f32());

let index_dir = dataset.indices_dir().child(uuid);
let graph_file = index_dir.child("diskann_graph.lance");
let filename = "diskann_graph.lance";
let graph_file = index_dir.child(filename);

let mut write_params = WriteGraphParams::default();
write_params.batch_size = 2048 * 10;
@@ -95,7 +98,7 @@ pub(crate) async fn build_diskann_index(
name,
uuid,
graph.data.num_columns(),
graph_file.to_string().as_str(),
filename,
&[medoid],
params.metric_type,
&params,
@@ -153,7 +156,8 @@ async fn init_graph(
let distribution = Uniform::new(0, batch.num_rows());
// Randomly connect to r neighbors.
for i in 0..graph.len() {
let mut neighbor_ids: HashSet<u32> = graph.neighbors(i)?.iter().copied().collect();
let mut neighbor_ids: HashSet<u32> =
graph.neighbors(i).await?.values().iter().copied().collect();

while neighbor_ids.len() < r {
let neighbor_id = rng.sample(distribution);
@@ -162,17 +166,9 @@ async fn init_graph(
}
}

// Make bidirectional connections.
{
let n = graph.neighbors_mut(i);
n.clear();
n.extend(neighbor_ids.iter().copied());
// Release mutable borrow on graph.
}
{
for neighbor_id in neighbor_ids.iter() {
graph.add_neighbor(*neighbor_id as usize, i);
}
let new_neighbors = Arc::new(UInt32Array::from_iter(neighbor_ids.iter().copied()));
graph.set_neighbors(i, new_neighbors);
}
}
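
For readers following `init_graph`: each vertex is topped up with random neighbors until it has `r` of them, and after this change the neighbor list is simply replaced rather than also adding reverse edges. A rough, dependency-free sketch of that idea, assuming an in-memory adjacency list and a tiny xorshift generator in place of the `rand` crate used by the real code. The `XorShift` type, `init_random_graph`, and the self-loop exclusion are all assumptions for illustration, since the relevant lines are elided from the hunk:

```rust
use std::collections::HashSet;

/// Tiny xorshift64 generator so the sketch needs no external crates.
/// (Stand-in for `rand`; not what the PR uses.)
struct XorShift(u64);

impl XorShift {
    /// Returns a pseudo-random value in `0..n`.
    fn next_below(&mut self, n: usize) -> usize {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 % n as u64) as usize
    }
}

/// Hypothetical helper: give every vertex exactly `r` distinct random
/// out-neighbors. Skipping self-loops is an assumption of this sketch.
fn init_random_graph(num_vertices: usize, r: usize, seed: u64) -> Vec<Vec<u32>> {
    assert!(r < num_vertices, "need at least r + 1 vertices");
    let mut rng = XorShift(seed.max(1)); // xorshift state must be non-zero
    (0..num_vertices)
        .map(|i| {
            let mut neighbor_ids: HashSet<u32> = HashSet::new();
            while neighbor_ids.len() < r {
                let candidate = rng.next_below(num_vertices);
                if candidate != i {
                    neighbor_ids.insert(candidate as u32);
                }
            }
            // Replace the neighbor list wholesale; no reverse edges are added here.
            let mut neighbors: Vec<u32> = neighbor_ids.into_iter().collect();
            neighbors.sort_unstable();
            neighbors
        })
        .collect()
}
```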

@@ -192,16 +188,16 @@ fn distance(matrix: &MatrixView, i: usize, j: usize) -> Result<f32> {
}

/// Algorithm 2 in the paper.
async fn robust_prune<V: Vertex + Clone>(
async fn robust_prune<V: Vertex + Clone + Sync + Send>(
graph: &GraphBuilder<V>,
id: usize,
mut visited: HashSet<usize>,
alpha: f32,
r: usize,
) -> Result<Vec<u32>> {
visited.remove(&id);
let neighbors = graph.neighbors(id)?;
visited.extend(neighbors.iter().map(|id| *id as usize));
let neighbors = graph.neighbors(id).await?;
visited.extend(neighbors.values().iter().map(|id| *id as usize));

let mut heap: BinaryHeap<VertexWithDistance> = visited
.iter()
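
Since the hunk only shows the start of `robust_prune`, here is a self-contained sketch of Algorithm 2 from the DiskANN paper for reference. It assumes plain in-memory vectors and Euclidean distance instead of `GraphBuilder` and the crate's metric types, so the signature and the `l2` helper are illustrative only and not the PR's implementation:

```rust
use std::collections::HashSet;

/// Euclidean distance between two vectors of equal length.
fn l2(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y) * (x - y))
        .sum::<f32>()
        .sqrt()
}

/// Sketch of RobustPrune: pick at most `r` out-neighbors for vertex `id`.
/// `neighbors` are the current out-neighbors of `id`; `candidates` is the
/// visited set from the greedy search.
fn robust_prune(
    vectors: &[Vec<f32>],
    neighbors: &[u32],
    id: usize,
    mut candidates: HashSet<usize>,
    alpha: f32,
    r: usize,
) -> Vec<u32> {
    // V <- (V ∪ N_out(p)) \ {p}
    candidates.extend(neighbors.iter().map(|&n| n as usize));
    candidates.remove(&id);

    let mut pruned: Vec<u32> = Vec::with_capacity(r);
    while pruned.len() < r {
        // Closest remaining candidate to `id`; ties are broken by vertex id.
        let closest = match candidates.iter().copied().min_by(|&a, &b| {
            l2(&vectors[id], &vectors[a])
                .total_cmp(&l2(&vectors[id], &vectors[b]))
                .then(a.cmp(&b))
        }) {
            Some(c) => c,
            None => break,
        };
        pruned.push(closest as u32);
        // Drop every candidate c with alpha * d(closest, c) <= d(id, c);
        // this also removes `closest` itself.
        candidates.retain(|&c| {
            alpha * l2(&vectors[closest], &vectors[c]) > l2(&vectors[id], &vectors[c])
        });
    }
    pruned
}
```

A larger `alpha` keeps more long-range edges: in the sketch, a far-away candidate survives pruning only if `alpha` times its distance to the chosen neighbor still exceeds its distance to `id`.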
@@ -266,7 +262,7 @@ async fn find_medoid(vectors: &MatrixView, metric_type: MetricType) -> Result<us
}

/// One pass of index building.
async fn index_once<V: Vertex + Clone>(
async fn index_once<V: Vertex + Clone + Sync + Send>(
graph: &mut GraphBuilder<V>,
medoid: usize,
alpha: f32,
@@ -283,20 +279,21 @@ async fn index_once<V: Vertex + Clone>(
.row(i)
.ok_or_else(|| Error::Index(format!("Cannot find vector with id {}", id)))?;

let state = greedy_search(graph, medoid, vector, 1, l)?;

graph
.neighbors_mut(id)
.extend(state.visited.iter().map(|id| *id as u32));
let state = greedy_search(graph, medoid, vector, 1, l).await?;

let neighbors = robust_prune(graph, id, state.visited, alpha, r).await?;
graph.set_neighbors(id, neighbors.to_vec());
graph.set_neighbors(
id,
Arc::new(neighbors.iter().copied().collect::<UInt32Array>()),
);

let fixed_graph: &GraphBuilder<V> = graph;
let neighbours = stream::iter(neighbors)
.map(|j| async move {
let mut neighbor_set: HashSet<usize> = fixed_graph
.neighbors(j as usize)?
.neighbors(j as usize)
.await?
.values()
.iter()
.map(|v| *v as usize)
.collect();
@@ -316,7 +313,7 @@ async fn index_once<V: Vertex + Clone>(
.try_collect::<Vec<_>>()
.await?;
for (j, nbs) in neighbours {
graph.set_neighbors(j, nbs);
graph.set_neighbors(j, Arc::new(nbs.into_iter().collect::<UInt32Array>()));
}
}

23 changes: 21 additions & 2 deletions rust/src/index/vector/diskann/row_vertex.rs
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;

use arrow_array::Float32Array;
use byteorder::{ByteOrder, LE};

@@ -23,7 +25,6 @@ use crate::Result;
pub(crate) struct RowVertex {
pub(crate) row_id: u64,

#[allow(dead_code)]
pub(crate) vector: Option<Float32Array>,
}

@@ -33,10 +34,28 @@ impl RowVertex {
}
}

impl Vertex for RowVertex {}
impl Vertex for RowVertex {
fn vector(&self) -> &[f32] {
self.vector.as_ref().unwrap().values()
}

fn as_any(&self) -> &dyn Any {
self
}

fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}

pub(crate) struct RowVertexSerDe {}

impl RowVertexSerDe {
pub(crate) fn new() -> Self {
Self {}
}
}

impl VertexSerDe<RowVertex> for RowVertexSerDe {
fn size(&self) -> usize {
8
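
`RowVertexSerDe::size()` returns 8, which matches a single little-endian `u64` row id per vertex, with the vector left to be loaded on demand. A small sketch of what such a fixed-width encoding could look like, using the standard library's `to_le_bytes` / `from_le_bytes` rather than the `byteorder` crate imported by the real file. The struct and both functions are simplified stand-ins, and the assumption that only `row_id` is persisted is inferred from the size, not confirmed by the hunk:

```rust
/// Simplified stand-in for `RowVertex`: only the row id is persisted (assumption).
#[derive(Debug, PartialEq)]
struct RowVertex {
    row_id: u64,
}

/// Serialized size of one vertex in bytes, matching `size()` in the hunk above.
const VERTEX_SIZE: usize = 8;

/// Encode the row id as 8 little-endian bytes.
fn serialize(vertex: &RowVertex) -> [u8; VERTEX_SIZE] {
    vertex.row_id.to_le_bytes()
}

/// Decode a vertex from the first 8 bytes of `data`.
fn deserialize(data: &[u8]) -> Result<RowVertex, String> {
    if data.len() < VERTEX_SIZE {
        return Err(format!(
            "expected at least {} bytes, got {}",
            VERTEX_SIZE,
            data.len()
        ));
    }
    let mut bytes = [0u8; VERTEX_SIZE];
    bytes.copy_from_slice(&data[..VERTEX_SIZE]);
    Ok(RowVertex {
        row_id: u64::from_le_bytes(bytes),
    })
}
```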