Skip to content

Commit

Permalink
perf: improve PQ computing distances (lancedb#3150)
Browse files Browse the repository at this point in the history
this is done by make the compiler know the size of distance table slice
```
5242880,L2,PQ=96,DIM=1536
                        time:   [148.44 ms 149.47 ms 150.50 ms]
                        change: [-53.716% -53.486% -53.252%] (p = 0.00 < 0.10)
                        Performance has improved.

5242880,Cosine,PQ=96,DIM=1536
                        time:   [191.84 ms 192.21 ms 192.75 ms]
                        change: [-46.738% -46.621% -46.461%] (p = 0.00 < 0.10)
                        Performance has improved.
```

---------

Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal authored Nov 22, 2024
1 parent bf2ce1f commit d79e870
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 38 deletions.
4 changes: 4 additions & 0 deletions rust/lance-index/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ harness = false
name = "pq_dist_table"
harness = false

[[bench]]
name = "4bitpq_dist_table"
harness = false

[[bench]]
name = "pq_assignment"
harness = false
Expand Down
64 changes: 64 additions & 0 deletions rust/lance-index/benches/4bitpq_dist_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Benchmark of building PQ distance table.
use std::iter::repeat;

use arrow_array::types::Float32Type;
use arrow_array::{FixedSizeListArray, UInt8Array};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use lance_arrow::FixedSizeListArrayExt;
use lance_index::vector::pq::ProductQuantizer;
use lance_linalg::distance::DistanceType;
use lance_testing::datagen::generate_random_array_with_seed;
use rand::{prelude::StdRng, Rng, SeedableRng};

#[cfg(target_os = "linux")]
use pprof::criterion::{Output, PProfProfiler};

const PQ: usize = 96;
const DIM: usize = 1536;
const TOTAL: usize = 16 * 1000;

fn dist_table(c: &mut Criterion) {
let codebook = generate_random_array_with_seed::<Float32Type>(256 * DIM, [88; 32]);
let query = generate_random_array_with_seed::<Float32Type>(DIM, [32; 32]);

let mut rnd = StdRng::from_seed([32; 32]);
let code = UInt8Array::from_iter_values(repeat(rnd.gen::<u8>()).take(TOTAL * PQ));

for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot].iter() {
let pq = ProductQuantizer::new(
PQ,
4,
DIM,
FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(),
*dt,
);

c.bench_function(
format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(),
|b| {
b.iter(|| {
black_box(pq.compute_distances(&query, &code).unwrap());
})
},
);
}
}

#[cfg(target_os = "linux")]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10)
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = dist_table);

#[cfg(not(target_os = "linux"))]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10);
targets = dist_table);

criterion_main!(benches);
51 changes: 18 additions & 33 deletions rust/lance-index/benches/pq_dist_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use pprof::criterion::{Output, PProfProfiler};

const PQ: usize = 96;
const DIM: usize = 1536;
const TOTAL: usize = 5 * 1024 * 1024;
const TOTAL: usize = 16 * 1000;

fn dist_table(c: &mut Criterion) {
let codebook = generate_random_array_with_seed::<Float32Type>(256 * DIM, [88; 32]);
Expand All @@ -28,39 +28,24 @@ fn dist_table(c: &mut Criterion) {
let mut rnd = StdRng::from_seed([32; 32]);
let code = UInt8Array::from_iter_values(repeat(rnd.gen::<u8>()).take(TOTAL * PQ));

let l2_pq = ProductQuantizer::new(
PQ,
8,
DIM,
FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(),
DistanceType::L2,
);
for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot].iter() {
let pq = ProductQuantizer::new(
PQ,
8,
DIM,
FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(),
*dt,
);

c.bench_function(
format!("{},L2,PQ={},DIM={}", TOTAL, PQ, DIM).as_str(),
|b| {
b.iter(|| {
black_box(l2_pq.compute_distances(&query, &code).unwrap().len());
})
},
);

let cosine_pq = ProductQuantizer::new(
PQ,
8,
DIM,
FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(),
DistanceType::Cosine,
);

c.bench_function(
format!("{},Cosine,PQ={},DIM={}", TOTAL, PQ, DIM).as_str(),
|b| {
b.iter(|| {
black_box(cosine_pq.compute_distances(&query, &code).unwrap());
})
},
);
c.bench_function(
format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(),
|b| {
b.iter(|| {
black_box(pq.compute_distances(&query, &code).unwrap());
})
},
);
}
}

#[cfg(target_os = "linux")]
Expand Down
19 changes: 14 additions & 5 deletions rust/lance-index/src/vector/pq/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ pub(super) fn compute_l2_distance(
// so code[i * num_vectors + j] is the code of i-th sub-vector of the j-th vector.
let num_vectors = code.len() / num_sub_vectors;
let mut distances = vec![0.0_f32; num_vectors];
let num_centroids = 2_usize.pow(num_bits);
// it must be 8
const NUM_CENTROIDS: usize = 2_usize.pow(8);
for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
let dist_table = &distance_table[sub_vec_idx * num_centroids..];
let dist_table =
&distance_table[sub_vec_idx * NUM_CENTROIDS..(sub_vec_idx + 1) * NUM_CENTROIDS];
debug_assert_eq!(vec_indices.len(), distances.len());
vec_indices
.iter()
Expand All @@ -103,9 +105,16 @@ pub(super) fn compute_l2_distance_4bit(
) -> Vec<f32> {
let num_vectors = code.len() * 2 / num_sub_vectors;
let mut distances = vec![0.0_f32; num_vectors];
let num_centroids = 2_usize.pow(4);
const NUM_CENTROIDS: usize = 2_usize.pow(4);
for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
let dist_table = &distance_table[sub_vec_idx * 2 * num_centroids..];
let dist_table: &[f32; NUM_CENTROIDS] = &distance_table
[sub_vec_idx * 2 * NUM_CENTROIDS..(sub_vec_idx * 2 + 1) * NUM_CENTROIDS]
.try_into()
.unwrap();
let dist_table_next: &[f32; NUM_CENTROIDS] = &distance_table
[(sub_vec_idx * 2 + 1) * NUM_CENTROIDS..(sub_vec_idx * 2 + 2) * NUM_CENTROIDS]
.try_into()
.unwrap();
debug_assert_eq!(vec_indices.len(), distances.len());
vec_indices
.iter()
Expand All @@ -115,7 +124,7 @@ pub(super) fn compute_l2_distance_4bit(
let current_idx = centroid_idx & 0xF;
let next_idx = centroid_idx >> 4;
*sum += dist_table[current_idx as usize];
*sum += dist_table[num_centroids + next_idx as usize];
*sum += dist_table_next[next_idx as usize];
});
}

Expand Down

0 comments on commit d79e870

Please sign in to comment.