Skip to content

Commit

Permalink
perf: avoid copying of creating memory dist calculator (lancedb#2219)
Browse files Browse the repository at this point in the history
Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal authored Apr 23, 2024
1 parent 99008c6 commit 1e08425
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 43 deletions.
2 changes: 1 addition & 1 deletion rust/lance-index/benches/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ fn bench_hnsw(c: &mut Criterion) {
.await
.unwrap();
let uids: HashSet<u32> = hnsw
.search(query, K, 300, None)
.search(query.clone(), K, 300, None)
.unwrap()
.iter()
.map(|node| node.id)
Expand Down
20 changes: 11 additions & 9 deletions rust/lance-index/src/vector/graph/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
use std::sync::Arc;

use super::storage::{DistCalculator, VectorStorage};
use arrow::array::AsArray;
use arrow_array::types::Float32Type;
use arrow_array::ArrayRef;
use lance_linalg::{distance::MetricType, MatrixView};

/// All data are stored in memory
Expand All @@ -26,7 +28,7 @@ impl InMemoryVectorStorage {
}
}

pub fn vector(&self, id: u32) -> &[f32] {
pub fn vector(&self, id: u32) -> ArrayRef {
self.vectors.row(id as usize).unwrap()
}
}
Expand All @@ -48,40 +50,40 @@ impl VectorStorage for InMemoryVectorStorage {
self.metric_type
}

fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator> {
fn dist_calculator(&self, query: ArrayRef) -> Box<dyn DistCalculator> {
Box::new(InMemoryDistanceCal {
vectors: self.vectors.clone(),
query: query.to_vec(),
query,
metric_type: self.metric_type,
})
}

fn dist_calculator_from_id(&self, id: u32) -> Box<dyn DistCalculator> {
Box::new(InMemoryDistanceCal {
vectors: self.vectors.clone(),
query: self.vectors.row(id as usize).unwrap().to_vec(),
query: self.vectors.row(id as usize).unwrap(),
metric_type: self.metric_type,
})
}

/// Distance between two vectors.
fn distance_between(&self, a: u32, b: u32) -> f32 {
let vector1 = self.vectors.row(a as usize).unwrap();
let vector2 = self.vectors.row(b as usize).unwrap();
let vector1 = self.vectors.row_ref(a as usize).unwrap();
let vector2 = self.vectors.row_ref(b as usize).unwrap();
self.metric_type.func()(vector1, vector2)
}
}

struct InMemoryDistanceCal {
vectors: Arc<MatrixView<Float32Type>>,
query: Vec<f32>,
query: ArrayRef,
metric_type: MetricType,
}

impl DistCalculator for InMemoryDistanceCal {
#[inline]
fn distance(&self, id: u32) -> f32 {
let vector = self.vectors.row(id as usize).unwrap();
self.metric_type.func()(&self.query, vector)
let vector = self.vectors.row_ref(id as usize).unwrap();
self.metric_type.func()(self.query.as_primitive::<Float32Type>().values(), vector)
}
}
3 changes: 2 additions & 1 deletion rust/lance-index/src/vector/graph/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use std::any::Any;

use arrow_array::ArrayRef;
use lance_linalg::distance::MetricType;

pub trait DistCalculator {
Expand Down Expand Up @@ -35,7 +36,7 @@ pub trait VectorStorage: Send + Sync {
///
/// Using dist calculator can be more efficient as it can pre-compute some
/// values.
fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator>;
fn dist_calculator(&self, query: ArrayRef) -> Box<dyn DistCalculator>;

fn dist_calculator_from_id(&self, id: u32) -> Box<dyn DistCalculator>;

Expand Down
9 changes: 5 additions & 4 deletions rust/lance-index/src/vector/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::ops::Range;
use std::sync::Arc;

use arrow::datatypes::UInt32Type;
use arrow_array::ArrayRef;
use arrow_array::{
builder::{ListBuilder, UInt32Builder},
cast::AsArray,
Expand Down Expand Up @@ -347,7 +348,7 @@ impl HNSW {
/// A list of `(id_in_graph, distance)` pairs. Or Error if the search failed.
pub fn search(
&self,
query: &[f32],
query: ArrayRef,
k: usize,
ef: usize,
bitset: Option<RoaringBitmap>,
Expand Down Expand Up @@ -593,7 +594,7 @@ mod tests {
fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashSet<u32> {
let mut dists = vec![];
for i in 0..mat.num_rows() {
let dist = lance_linalg::distance::l2_distance(query, mat.row(i).unwrap());
let dist = lance_linalg::distance::l2_distance(query, mat.row_ref(i).unwrap());
dists.push((dist, i as u32));
}
dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
Expand Down Expand Up @@ -625,12 +626,12 @@ mod tests {
.unwrap();

let results: HashSet<u32> = hnsw
.search(q, K, 128, None)
.search(q.clone(), K, 128, None)
.unwrap()
.iter()
.map(|node| node.id)
.collect();
let gt = ground_truth(&mat, q, K);
let gt = ground_truth(&mat, q.as_primitive::<Float32Type>().values(), K);
let recall = results.intersection(&gt).count() as f32 / K as f32;
assert!(recall >= 0.9, "Recall: {}", recall);
}
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-index/src/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ impl<T: ArrowFloatType + Dot + L2 + ArrowPrimitiveType> Ivf for IvfImpl<T> {
.chunks_exact(dim)
.zip(part_ids.values())
.flat_map(|(vector, &part_id)| {
let centroid = self.centroids.row(part_id as usize).unwrap();
let centroid = self.centroids.row_ref(part_id as usize).unwrap();
vector.iter().zip(centroid.iter()).map(|(&v, &c)| v - c)
})
.collect::<Vec<_>>();
Expand Down
5 changes: 3 additions & 2 deletions rust/lance-index/src/vector/pq/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use std::{cmp::min, collections::HashMap, sync::Arc};

use arrow_array::ArrayRef;
use arrow_array::{
cast::AsArray,
types::{Float32Type, UInt64Type, UInt8Type},
Expand Down Expand Up @@ -410,13 +411,13 @@ impl VectorStorage for ProductQuantizationStorage {
self.metric_type
}

fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator> {
fn dist_calculator(&self, query: ArrayRef) -> Box<dyn DistCalculator> {
Box::new(PQDistCalculator::new(
self.codebook.values(),
self.num_bits,
self.num_sub_vectors,
self.pq_code.clone(),
query,
query.as_primitive::<Float32Type>().values(),
self.metric_type(),
))
}
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-index/src/vector/pq/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub(super) fn divide_to_subvectors<T: ArrowFloatType>(
for i in 0..m {
let mut builder = Vec::with_capacity(capacity);
for j in 0..data.num_rows() {
let row = data.row(j).unwrap();
let row = data.row_ref(j).unwrap();
let start = i * sub_vector_length;
builder.extend_from_slice(&row[start..start + sub_vector_length]);
}
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-index/src/vector/residual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl<T: ArrowFloatType> Transformer for ResidualTransform<T> {
.chunks_exact(dim as usize)
.zip(part_ids.as_primitive::<UInt32Type>().values().iter())
.for_each(|(vector, &part_id)| {
let centroid = self.centroids.row(part_id as usize).unwrap();
let centroid = self.centroids.row_ref(part_id as usize).unwrap();
// TODO: SIMD
residual_arr.extend(
vector
Expand Down
13 changes: 5 additions & 8 deletions rust/lance-index/src/vector/sq/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use std::{ops::Range, sync::Arc};

use arrow::{array::AsArray, datatypes::Float32Type};
use arrow_array::{Array, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array};
use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array};
use async_trait::async_trait;
use lance_core::{Error, Result, ROW_ID};
use lance_file::reader::FileReader;
Expand Down Expand Up @@ -212,7 +212,7 @@ impl VectorStorage for ScalarQuantizationStorage {
///
/// Using dist calculator can be more efficient as it can pre-compute some
/// values.
fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator> {
fn dist_calculator(&self, query: ArrayRef) -> Box<dyn DistCalculator> {
Box::new(SQDistCalculator::new(
query,
self.sq_codes.clone(),
Expand Down Expand Up @@ -243,12 +243,9 @@ struct SQDistCalculator {
}

impl SQDistCalculator {
fn new(query: &[f32], sq_codes: Arc<FixedSizeListArray>, bounds: Range<f64>) -> Self {
// TODO: support f16/f64
let query_sq_code = scale_to_u8::<Float32Type>(query, bounds)
.into_iter()
.collect::<Vec<_>>();

fn new(query: ArrayRef, sq_codes: Arc<FixedSizeListArray>, bounds: Range<f64>) -> Self {
let query_sq_code =
scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), bounds);
Self {
query_sq_code,
sq_codes,
Expand Down
19 changes: 16 additions & 3 deletions rust/lance-linalg/src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use std::sync::Arc;

use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray};
use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, FixedSizeListArray};
use arrow_schema::{ArrowError, DataType};
use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, FloatType};
use num_traits::{AsPrimitive, Float, FromPrimitive, ToPrimitive};
Expand Down Expand Up @@ -134,7 +134,7 @@ impl<T: ArrowFloatType> MatrixView<T> {
/// Returns a row at index `i`. Returns `None` if the index is out of bound.
///
/// # Panics if the matrix is transposed.
pub fn row(&self, i: usize) -> Option<&[T::Native]> {
pub fn row_ref(&self, i: usize) -> Option<&[T::Native]> {
assert!(
!self.transpose,
"Centroid is not defined for transposed matrix."
Expand All @@ -147,6 +147,19 @@ impl<T: ArrowFloatType> MatrixView<T> {
}
}

/// Returns the row at index `i` as an [`ArrayRef`]. Returns `None` if the
/// index is out of bound (`i >= self.num_rows()`).
///
/// The row is obtained by slicing the underlying data array at offset
/// `i * num_columns()` with length `num_columns()`. Use `row_ref` instead
/// when a borrowed `&[T::Native]` slice is sufficient.
///
/// NOTE(review): presumably the slice shares the underlying buffer rather
/// than copying it (Arrow `slice` semantics) -- confirm against the Arrow docs.
///
/// # Panics if the matrix is transposed.
pub fn row(&self, i: usize) -> Option<ArrayRef> {
assert!(
!self.transpose,
"Centroid is not defined for transposed matrix."
);
if i >= self.num_rows() {
None
} else {
// Each row spans `dim` consecutive values, so row `i` starts at `i * dim`.
let dim = self.num_columns();
Some(self.data.slice(i * dim, dim))
}
}

/// Compute the centroid from all the rows. Returns `None` if this matrix is empty.
///
/// # Panics if the matrix is transposed.
Expand Down Expand Up @@ -359,7 +372,7 @@ impl<'a, T: ArrowFloatType> Iterator for MatrixRowIter<'a, T> {
fn next(&mut self) -> Option<Self::Item> {
let cur_idx = self.cur_idx;
self.cur_idx += 1;
self.data.row(cur_idx)
self.data.row_ref(cur_idx)
}
}

Expand Down
6 changes: 3 additions & 3 deletions rust/lance/examples/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct Args {
fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashSet<u32> {
let mut dists = vec![];
for i in 0..mat.num_rows() {
let dist = lance_linalg::distance::l2_distance(query, mat.row(i).unwrap());
let dist = lance_linalg::distance::l2_distance(query, mat.row_ref(i).unwrap());
dists.push((dist, i as u32));
}
dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
Expand Down Expand Up @@ -80,7 +80,7 @@ async fn main() {

let q = mat.row(0).unwrap();
let k = 10;
let gt = ground_truth(&mat, q, k);
let gt = ground_truth(&mat, q.as_primitive::<Float32Type>().values(), k);

for ef_construction in [15, 30, 50] {
let now = std::time::Instant::now();
Expand All @@ -98,7 +98,7 @@ async fn main() {
let construct_time = now.elapsed().as_secs_f32();
let now = std::time::Instant::now();
let results: HashSet<u32> = hnsw
.search(q, k, args.ef, None)
.search(q.clone(), k, args.ef, None)
.unwrap()
.iter()
.map(|node| node.id)
Expand Down
10 changes: 2 additions & 8 deletions rust/lance/src/index/vector/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use std::{
sync::Arc,
};

use arrow_array::{cast::AsArray, types::Float32Type, Float32Array, RecordBatch, UInt64Array};
use arrow_array::{Float32Array, RecordBatch, UInt64Array};

use arrow_schema::DataType;
use async_trait::async_trait;
use lance_arrow::*;
use lance_core::{datatypes::Schema, Error, Result, ROW_ID};
use lance_file::reader::FileReader;
use lance_index::{
Expand Down Expand Up @@ -164,12 +163,7 @@ impl<Q: Quantization + Send + Sync + 'static> VectorIndex for HNSWIndex<Q> {
});
}

let results = self.hnsw.search(
query.key.as_primitive::<Float32Type>().as_slice(),
k,
ef,
bitmap,
)?;
let results = self.hnsw.search(query.key.clone(), k, ef, bitmap)?;

let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| row_ids[x.id as usize]));
let distances = Arc::new(Float32Array::from_iter_values(
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2210,7 +2210,7 @@ mod tests {
fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashSet<u32> {
let mut dists = vec![];
for i in 0..mat.num_rows() {
let dist = lance_linalg::distance::l2_distance(query, mat.row(i).unwrap());
let dist = lance_linalg::distance::l2_distance(query, mat.row_ref(i).unwrap());
dists.push((dist, i as u32));
}
dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
Expand Down

0 comments on commit 1e08425

Please sign in to comment.