diff --git a/rust/lance-index/benches/pq_dist_table.rs b/rust/lance-index/benches/pq_dist_table.rs index 8e1c49c4a0..515a309a0f 100644 --- a/rust/lance-index/benches/pq_dist_table.rs +++ b/rust/lance-index/benches/pq_dist_table.rs @@ -9,6 +9,7 @@ use arrow_array::types::Float32Type; use arrow_array::{FixedSizeListArray, UInt8Array}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use lance_arrow::FixedSizeListArrayExt; +use lance_index::vector::pq::distance::*; use lance_index::vector::pq::ProductQuantizer; use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_seed; @@ -21,7 +22,52 @@ const PQ: usize = 96; const DIM: usize = 1536; const TOTAL: usize = 16 * 1000; -fn dist_table(c: &mut Criterion) { +fn construct_dist_table(c: &mut Criterion) { + let codebook = generate_random_array_with_seed::(256 * DIM, [88; 32]); + let query = generate_random_array_with_seed::(DIM, [32; 32]); + + c.bench_function( + format!( + "construct_dist_table: {},PQ={},DIM={}", + DistanceType::L2, + PQ, + DIM + ) + .as_str(), + |b| { + b.iter(|| { + black_box(build_distance_table_l2( + codebook.values(), + 8, + PQ, + query.values(), + )); + }) + }, + ); + + c.bench_function( + format!( + "construct_dist_table: {},PQ={},DIM={}", + DistanceType::Dot, + PQ, + DIM + ) + .as_str(), + |b| { + b.iter(|| { + black_box(build_distance_table_dot( + codebook.values(), + 8, + PQ, + query.values(), + )); + }) + }, + ); +} + +fn compute_distances(c: &mut Criterion) { let codebook = generate_random_array_with_seed::(256 * DIM, [88; 32]); let query = generate_random_array_with_seed::(DIM, [32; 32]); @@ -38,7 +84,7 @@ fn dist_table(c: &mut Criterion) { ); c.bench_function( - format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(), + format!("compute_distances: {},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(), |b| { b.iter(|| { black_box(pq.compute_distances(&query, &code).unwrap()); @@ -53,12 +99,12 @@ criterion_group!( name=benches; config = Criterion::default().significance_level(0.1).sample_size(10) .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = dist_table); + targets = construct_dist_table, compute_distances); #[cfg(not(target_os = "linux"))] criterion_group!( name=benches; config = Criterion::default().significance_level(0.1).sample_size(10); - targets = dist_table); + targets = construct_dist_table, compute_distances); criterion_main!(benches); diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index ad0147dd3a..467599157b 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -23,7 +23,7 @@ use storage::{ProductQuantizationMetadata, ProductQuantizationStorage, PQ_METADA use tracing::instrument; pub mod builder; -mod distance; +pub mod distance; pub mod storage; pub mod transform; pub(crate) mod utils; @@ -96,6 +96,26 @@ impl ProductQuantizer { #[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)] fn transform(&self, vectors: &dyn Array) -> Result + where + T::Native: Float + L2 + Dot, + { + match self.num_bits { + 4 => self.transform_impl::<4, T>(vectors), + 8 => self.transform_impl::<8, T>(vectors), + _ => Err(Error::Index { + message: format!( + "ProductQuantization: num_bits {} not supported", + self.num_bits + ), + location: location!(), + }), + } + } + + fn transform_impl( + &self, + vectors: &dyn Array, + ) -> Result where T::Native: Float + L2 + Dot, { @@ -108,8 +128,7 @@ impl ProductQuantizer { })?; let num_sub_vectors = self.num_sub_vectors; let dim = self.dimension; - let num_bits = self.num_bits; - if num_bits == 4 && num_sub_vectors % 2 != 0 { + if NUM_BITS == 4 && num_sub_vectors % 2 != 0 { return Err(Error::Index { message: format!( "PQ: num_sub_vectors must be divisible by 2 for num_bits=4, but got {}", @@ -132,17 +151,16 @@ impl ProductQuantizer { .chunks_exact(sub_dim) .enumerate() .map(|(sub_idx, sub_vector)| { - let centroids = get_sub_vector_centroids( + let centroids = get_sub_vector_centroids::( codebook.values(), dim, - num_bits, num_sub_vectors, sub_idx, ); compute_partition(centroids, sub_vector, distance_type).unwrap() as u8 }) .collect::>(); - if num_bits == 4 { + if NUM_BITS == 4 { sub_vec_code .chunks_exact(2) .map(|v| (v[1] << 4) | v[0]) @@ -153,7 +171,7 @@ impl ProductQuantizer { }) .collect::>(); - let num_sub_vectors_in_byte = if num_bits == 4 { + let num_sub_vectors_in_byte = if NUM_BITS == 4 { num_sub_vectors / 2 } else { num_sub_vectors @@ -321,13 +339,24 @@ impl ProductQuantizer { /// /// Returns a flatten `num_centroids * sub_vector_width` f32 array. pub fn centroids(&self, sub_vector_idx: usize) -> &[T::Native] { - get_sub_vector_centroids( - self.codebook.values().as_primitive::().values(), - self.dimension, - self.num_bits, - self.num_sub_vectors, - sub_vector_idx, - ) + match self.num_bits { + 4 => get_sub_vector_centroids::<4, _>( + self.codebook.values().as_primitive::().values(), + self.dimension, + self.num_sub_vectors, + sub_vector_idx, + ), + 8 => get_sub_vector_centroids::<8, _>( + self.codebook.values().as_primitive::().values(), + self.dimension, + self.num_sub_vectors, + sub_vector_idx, + ), + _ => panic!( + "ProductQuantization: num_bits {} not supported", + self.num_bits + ), + } } } diff --git a/rust/lance-index/src/vector/pq/distance.rs b/rust/lance-index/src/vector/pq/distance.rs index 6d66d020e3..ddf98b099c 100644 --- a/rust/lance-index/src/vector/pq/distance.rs +++ b/rust/lance-index/src/vector/pq/distance.rs @@ -1,52 +1,82 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use core::panic; use std::cmp::min; use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2}; +use lance_table::utils::LanceIteratorExtension; use super::{num_centroids, utils::get_sub_vector_centroids}; /// Build a Distance Table from the query to each PQ centroid /// using L2 distance. -pub(super) fn build_distance_table_l2( +pub fn build_distance_table_l2( codebook: &[T], num_bits: u32, num_sub_vectors: usize, query: &[T], ) -> Vec { - let dimension = query.len(); + match num_bits { + 4 => build_distance_table_l2_impl::<4, T>(codebook, num_sub_vectors, query), + 8 => build_distance_table_l2_impl::<8, T>(codebook, num_sub_vectors, query), + _ => panic!("Unsupported number of bits: {}", num_bits), + } +} +#[inline] +pub fn build_distance_table_l2_impl( + codebook: &[T], + num_sub_vectors: usize, + query: &[T], +) -> Vec { + let dimension = query.len(); let sub_vector_length = dimension / num_sub_vectors; + let num_centroids = 2_usize.pow(NUM_BITS); query .chunks_exact(sub_vector_length) .enumerate() .flat_map(|(i, sub_vec)| { let subvec_centroids = - get_sub_vector_centroids(codebook, dimension, num_bits, num_sub_vectors, i); + get_sub_vector_centroids::(codebook, dimension, num_sub_vectors, i); l2_distance_batch(sub_vec, subvec_centroids, sub_vector_length) }) + .exact_size(num_sub_vectors * num_centroids) .collect() } /// Build a Distance Table from the query to each PQ centroid /// using Dot distance. -pub(super) fn build_distance_table_dot( +pub fn build_distance_table_dot( codebook: &[T], num_bits: u32, num_sub_vectors: usize, query: &[T], +) -> Vec { + match num_bits { + 4 => build_distance_table_dot_impl::<4, T>(codebook, num_sub_vectors, query), + 8 => build_distance_table_dot_impl::<8, T>(codebook, num_sub_vectors, query), + _ => panic!("Unsupported number of bits: {}", num_bits), + } +} + +pub fn build_distance_table_dot_impl( + codebook: &[T], + num_sub_vectors: usize, + query: &[T], ) -> Vec { let dimension = query.len(); let sub_vector_length = dimension / num_sub_vectors; + let num_centroids = 2_usize.pow(NUM_BITS); query .chunks_exact(sub_vector_length) .enumerate() .flat_map(|(i, sub_vec)| { let subvec_centroids = - get_sub_vector_centroids(codebook, dimension, num_bits, num_sub_vectors, i); + get_sub_vector_centroids::(codebook, dimension, num_sub_vectors, i); dot_distance_batch(sub_vec, subvec_centroids, sub_vector_length) }) + .exact_size(num_sub_vectors * num_centroids) .collect() } diff --git a/rust/lance-index/src/vector/pq/utils.rs b/rust/lance-index/src/vector/pq/utils.rs index 7e73a1b9f4..8766eb8005 100644 --- a/rust/lance-index/src/vector/pq/utils.rs +++ b/rust/lance-index/src/vector/pq/utils.rs @@ -51,21 +51,20 @@ pub fn num_centroids(num_bits: impl Into) -> usize { } #[inline] -pub fn get_sub_vector_centroids( +pub fn get_sub_vector_centroids( codebook: &[T], dimension: usize, - num_bits: impl Into, num_sub_vectors: usize, sub_vector_idx: usize, ) -> &[T] { - assert!( + debug_assert!( sub_vector_idx < num_sub_vectors, "sub_vector idx: {}, num_sub_vectors: {}", sub_vector_idx, num_sub_vectors ); - let num_centroids = num_centroids(num_bits); + let num_centroids: usize = 2_usize.pow(NUM_BITS); let sub_vector_width = dimension / num_sub_vectors; &codebook[sub_vector_idx * num_centroids * sub_vector_width ..(sub_vector_idx + 1) * num_centroids * sub_vector_width] diff --git a/rust/lance-linalg/src/distance/l2.rs b/rust/lance-linalg/src/distance/l2.rs index a1c27a74b7..f3c98be709 100644 --- a/rust/lance-linalg/src/distance/l2.rs +++ b/rust/lance-linalg/src/distance/l2.rs @@ -242,7 +242,7 @@ pub fn l2_distance_batch<'a, T: L2>( debug_assert_eq!(from.len(), dimension); debug_assert_eq!(to.len() % dimension, 0); - Box::new(T::l2_batch(from, to, dimension)) + T::l2_batch(from, to, dimension) } fn do_l2_distance_arrow_batch(