From d79e870628b8d5f78cf018e6fd184c503c364b1e Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 22 Nov 2024 11:37:55 +0800 Subject: [PATCH] perf: improve PQ computing distances (#3150) This is done by making the compiler aware of the size of the distance table slice. ``` 5242880,L2,PQ=96,DIM=1536 time: [148.44 ms 149.47 ms 150.50 ms] change: [-53.716% -53.486% -53.252%] (p = 0.00 < 0.10) Performance has improved. 5242880,Cosine,PQ=96,DIM=1536 time: [191.84 ms 192.21 ms 192.75 ms] change: [-46.738% -46.621% -46.461%] (p = 0.00 < 0.10) Performance has improved. ``` --------- Signed-off-by: BubbleCal --- rust/lance-index/Cargo.toml | 4 ++ rust/lance-index/benches/4bitpq_dist_table.rs | 64 +++++++++++++++++++ rust/lance-index/benches/pq_dist_table.rs | 51 ++++++--------- rust/lance-index/src/vector/pq/distance.rs | 19 ++++-- 4 files changed, 100 insertions(+), 38 deletions(-) create mode 100644 rust/lance-index/benches/4bitpq_dist_table.rs diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index fb23a0fb45..12d38e5678 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -82,6 +82,10 @@ harness = false name = "pq_dist_table" harness = false +[[bench]] +name = "4bitpq_dist_table" +harness = false + [[bench]] name = "pq_assignment" harness = false diff --git a/rust/lance-index/benches/4bitpq_dist_table.rs b/rust/lance-index/benches/4bitpq_dist_table.rs new file mode 100644 index 0000000000..251159130a --- /dev/null +++ b/rust/lance-index/benches/4bitpq_dist_table.rs @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Benchmark of building PQ distance table.
+ +use std::iter::repeat; + +use arrow_array::types::Float32Type; +use arrow_array::{FixedSizeListArray, UInt8Array}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use lance_arrow::FixedSizeListArrayExt; +use lance_index::vector::pq::ProductQuantizer; +use lance_linalg::distance::DistanceType; +use lance_testing::datagen::generate_random_array_with_seed; +use rand::{prelude::StdRng, Rng, SeedableRng}; + +#[cfg(target_os = "linux")] +use pprof::criterion::{Output, PProfProfiler}; + +const PQ: usize = 96; +const DIM: usize = 1536; +const TOTAL: usize = 16 * 1000; + +fn dist_table(c: &mut Criterion) { + let codebook = generate_random_array_with_seed::(256 * DIM, [88; 32]); + let query = generate_random_array_with_seed::(DIM, [32; 32]); + + let mut rnd = StdRng::from_seed([32; 32]); + let code = UInt8Array::from_iter_values(repeat(rnd.gen::()).take(TOTAL * PQ)); + + for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot].iter() { + let pq = ProductQuantizer::new( + PQ, + 4, + DIM, + FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(), + *dt, + ); + + c.bench_function( + format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(), + |b| { + b.iter(|| { + black_box(pq.compute_distances(&query, &code).unwrap()); + }) + }, + ); + } +} + +#[cfg(target_os = "linux")] +criterion_group!( + name=benches; + config = Criterion::default().significance_level(0.1).sample_size(10) + .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); + targets = dist_table); + +#[cfg(not(target_os = "linux"))] +criterion_group!( + name=benches; + config = Criterion::default().significance_level(0.1).sample_size(10); + targets = dist_table); + +criterion_main!(benches); diff --git a/rust/lance-index/benches/pq_dist_table.rs b/rust/lance-index/benches/pq_dist_table.rs index 227112644d..8e1c49c4a0 100644 --- a/rust/lance-index/benches/pq_dist_table.rs +++ b/rust/lance-index/benches/pq_dist_table.rs @@ -19,7 +19,7 @@ use 
pprof::criterion::{Output, PProfProfiler}; const PQ: usize = 96; const DIM: usize = 1536; -const TOTAL: usize = 5 * 1024 * 1024; +const TOTAL: usize = 16 * 1000; fn dist_table(c: &mut Criterion) { let codebook = generate_random_array_with_seed::(256 * DIM, [88; 32]); @@ -28,39 +28,24 @@ fn dist_table(c: &mut Criterion) { let mut rnd = StdRng::from_seed([32; 32]); let code = UInt8Array::from_iter_values(repeat(rnd.gen::()).take(TOTAL * PQ)); - let l2_pq = ProductQuantizer::new( - PQ, - 8, - DIM, - FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(), - DistanceType::L2, - ); + for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot].iter() { + let pq = ProductQuantizer::new( + PQ, + 8, + DIM, + FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(), + *dt, + ); - c.bench_function( - format!("{},L2,PQ={},DIM={}", TOTAL, PQ, DIM).as_str(), - |b| { - b.iter(|| { - black_box(l2_pq.compute_distances(&query, &code).unwrap().len()); - }) - }, - ); - - let cosine_pq = ProductQuantizer::new( - PQ, - 8, - DIM, - FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(), - DistanceType::Cosine, - ); - - c.bench_function( - format!("{},Cosine,PQ={},DIM={}", TOTAL, PQ, DIM).as_str(), - |b| { - b.iter(|| { - black_box(cosine_pq.compute_distances(&query, &code).unwrap()); - }) - }, - ); + c.bench_function( + format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(), + |b| { + b.iter(|| { + black_box(pq.compute_distances(&query, &code).unwrap()); + }) + }, + ); + } } #[cfg(target_os = "linux")] diff --git a/rust/lance-index/src/vector/pq/distance.rs b/rust/lance-index/src/vector/pq/distance.rs index 9a3d876342..6d66d020e3 100644 --- a/rust/lance-index/src/vector/pq/distance.rs +++ b/rust/lance-index/src/vector/pq/distance.rs @@ -80,9 +80,11 @@ pub(super) fn compute_l2_distance( // so code[i * num_vectors + j] is the code of i-th sub-vector of the j-th vector. 
let num_vectors = code.len() / num_sub_vectors; let mut distances = vec![0.0_f32; num_vectors]; - let num_centroids = 2_usize.pow(num_bits); + // it must be 8 + const NUM_CENTROIDS: usize = 2_usize.pow(8); for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { - let dist_table = &distance_table[sub_vec_idx * num_centroids..]; + let dist_table = + &distance_table[sub_vec_idx * NUM_CENTROIDS..(sub_vec_idx + 1) * NUM_CENTROIDS]; debug_assert_eq!(vec_indices.len(), distances.len()); vec_indices .iter() @@ -103,9 +105,16 @@ pub(super) fn compute_l2_distance_4bit( ) -> Vec { let num_vectors = code.len() * 2 / num_sub_vectors; let mut distances = vec![0.0_f32; num_vectors]; - let num_centroids = 2_usize.pow(4); + const NUM_CENTROIDS: usize = 2_usize.pow(4); for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { - let dist_table = &distance_table[sub_vec_idx * 2 * num_centroids..]; + let dist_table: &[f32; NUM_CENTROIDS] = &distance_table + [sub_vec_idx * 2 * NUM_CENTROIDS..(sub_vec_idx * 2 + 1) * NUM_CENTROIDS] + .try_into() + .unwrap(); + let dist_table_next: &[f32; NUM_CENTROIDS] = &distance_table + [(sub_vec_idx * 2 + 1) * NUM_CENTROIDS..(sub_vec_idx * 2 + 2) * NUM_CENTROIDS] + .try_into() + .unwrap(); debug_assert_eq!(vec_indices.len(), distances.len()); vec_indices .iter() @@ -115,7 +124,7 @@ pub(super) fn compute_l2_distance_4bit( let current_idx = centroid_idx & 0xF; let next_idx = centroid_idx >> 4; *sum += dist_table[current_idx as usize]; - *sum += dist_table[num_centroids + next_idx as usize]; + *sum += dist_table_next[next_idx as usize]; }); }