
perf: in-register lookup table & SIMD for 4bit PQ #3178

Merged
merged 30 commits into lancedb:main on Dec 5, 2024

Conversation

BubbleCal
Contributor

@BubbleCal BubbleCal commented Nov 26, 2024

4-bit PQ is 3x faster than before:

16000,l2,PQ=96x4,DIM=1536
                        time:   [187.17 µs 187.95 µs 188.52 µs]
                        change: [-65.789% -65.641% -65.520%] (p = 0.00 < 0.10)
                        Performance has improved.

16000,cosine,PQ=96x4,DIM=1536
                        time:   [214.16 µs 214.52 µs 214.89 µs]
                        change: [-62.748% -62.594% -62.442%] (p = 0.00 < 0.10)
                        Performance has improved.

16000,dot,PQ=96x4,DIM=1536
                        time:   [190.12 µs 191.27 µs 192.22 µs]
                        change: [-65.496% -65.303% -65.086%] (p = 0.00 < 0.10)
                        Performance has improved.

Posting the 8-bit PQ results here for comparison; in short, 4-bit PQ is about 2x faster with the same index params:

compute_distances: 16000,l2,PQ=96,DIM=1536
                        time:   [405.11 µs 405.72 µs 406.92 µs]
                        change: [-0.2844% +0.1588% +0.6035%] (p = 0.50 > 0.10)
                        No change in performance detected.

compute_distances: 16000,cosine,PQ=96,DIM=1536
                        time:   [419.98 µs 421.05 µs 421.99 µs]
                        change: [-0.2540% +0.1098% +0.4928%] (p = 0.59 > 0.10)
                        No change in performance detected.

compute_distances: 16000,dot,PQ=96,DIM=1536
                        time:   [432.08 µs 433.63 µs 435.69 µs]
                        change: [-25.522% -25.243% -24.938%] (p = 0.00 < 0.10)
                        Performance has improved.
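For readers unfamiliar with the trick: with 4-bit PQ each sub-vector has only 16 centroids, so its distance table fits in a single 128-bit register and lookups can become byte shuffles instead of memory gathers. Below is a scalar sketch of the lookup only (a hypothetical helper, not the PR's SIMD code), assuming a flattened distance table with 16 entries per sub-vector and two 4-bit codes packed per byte:

```rust
// Scalar sketch of 4-bit PQ distance computation for one vector
// (hypothetical helper, not the PR's SIMD code). Each byte of `codes`
// packs two 4-bit centroid indices; `distance_table` holds 16 f32
// distances per sub-vector, flattened.
fn compute_pq_distance_scalar(codes: &[u8], distance_table: &[f32], num_sub_vectors: usize) -> f32 {
    debug_assert_eq!(codes.len() * 2, num_sub_vectors);
    let mut dist = 0.0_f32;
    for (i, &byte) in codes.iter().enumerate() {
        let lo = (byte & 0x0F) as usize; // code for sub-vector 2*i
        let hi = (byte >> 4) as usize;   // code for sub-vector 2*i + 1
        dist += distance_table[(2 * i) * 16 + lo];
        dist += distance_table[(2 * i + 1) * 16 + hi];
    }
    dist
}
```

The SIMD version in the PR performs the same table lookups 16 lanes at a time with an in-register shuffle on the quantized (u8) table.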

Signed-off-by: BubbleCal <[email protected]>
@codecov-commenter

codecov-commenter commented Nov 28, 2024

Codecov Report

Attention: Patch coverage is 57.07071% with 170 lines in your changes missing coverage. Please review.

Project coverage is 78.51%. Comparing base (6e84834) to head (5fc527c).

Files with missing lines | Patch % | Lines
rust/lance-linalg/src/simd/u8.rs | 47.35% | 169 Missing ⚠️
rust/lance-index/src/vector/pq/storage.rs | 66.66% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3178      +/-   ##
==========================================
- Coverage   78.62%   78.51%   -0.12%     
==========================================
  Files         243      244       +1     
  Lines       82889    83213     +324     
  Branches    82889    83213     +324     
==========================================
+ Hits        65170    65331     +161     
- Misses      14933    15099     +166     
+ Partials     2786     2783       -3     
Flag | Coverage Δ
unittests | 78.51% <57.07%> (-0.12%) ⬇️


@BubbleCal BubbleCal marked this pull request as ready for review November 28, 2024 08:15
Signed-off-by: BubbleCal <[email protected]>
// let qmax = distance_table
// .chunks(NUM_CENTROIDS)
// .tuple_windows()
// .map(|(a, b)| {
Contributor
Delete these?

Contributor Author
fixed

distances
}

// Quantize the distance table to u8
// returns quantized_distance_table
// used for only 4bit PQ so num_centroids must be 16
Contributor
Can you add a comment about what this returns?

Contributor Author
fixed
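For context on what this quantization does: the f32 distance table is mapped into u8 range with a min/scale pair so that lookups fit in byte lanes. A minimal sketch under assumed names (the function name and return shape here are hypothetical; the PR's actual code may differ):

```rust
// Hypothetical sketch: quantize an f32 distance table to u8.
// Returns (quantized table, min, scale) so a caller can recover
// approximate distances as d ≈ min + (q as f32) * scale.
fn quantize_distance_table(distance_table: &[f32]) -> (Vec<u8>, f32, f32) {
    let min = distance_table.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = distance_table.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let scale = (max - min) / 255.0;
    let quantized = distance_table
        .iter()
        .map(|&d| if scale > 0.0 { ((d - min) / scale).round() as u8 } else { 0 })
        .collect();
    (quantized, min, scale)
}
```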

@@ -278,7 +294,7 @@ mod tests {
let pq_codes = Vec::from_iter((0..num_vectors * num_sub_vectors).map(|v| v as u8));
let pq_codes = UInt8Array::from_iter_values(pq_codes);
let transposed_codes = transpose(&pq_codes, num_vectors, num_sub_vectors);
-        let distances = compute_l2_distance(
+        let distances = compute_pq_distance(
Contributor
We don't use dot anymore?

Contributor Author
compute_l2_distance and compute_dot_distance are the same, so we keep only one;
the difference is in building the distance table.

#[derive(Clone, Copy)]
pub struct u8x16(pub __m128i);

/// 16 of 32-bit `f32` values. Use 512-bit SIMD if possible.
Contributor
This is copied from simd/f32?

Contributor Author
fixed

}

#[inline]
pub fn right_shift_4(self) -> Self {
Contributor
Is this API compatible with portable_simd?

Contributor Author
I didn't see a bit-shifting operation in portable_simd.
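For reference, the operation under discussion shifts every byte lane right by 4 to expose the high nibble (the second 4-bit PQ code packed in each byte). A scalar equivalent, as a sketch (the name right_shift_4_scalar is made up here):

```rust
// Scalar model of u8x16::right_shift_4: shift each byte right by 4,
// extracting the high nibble of every lane.
fn right_shift_4_scalar(v: [u8; 16]) -> [u8; 16] {
    v.map(|b| b >> 4)
}
```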

}
#[cfg(target_arch = "loongarch64")]
unsafe {
Self(lasx_xvfrsh_b(self.0, 4))
Contributor
Huh, you figured out how to use loongarch?

Contributor Author
lol no, there's no way to test it; let me remove all the loongarch code for u8x16.

unsafe {
Self(vandq_u8(self.0, vdupq_n_u8(mask)))
}
#[cfg(target_arch = "loongarch64")]
Contributor
Shall we always have a fallback non-SIMD route?

Contributor Author
added

fn reduce_min(&self) -> u8 {
#[cfg(target_arch = "x86_64")]
unsafe {
let low = _mm_and_si128(self.0, _mm_set1_epi8(0xFF_u8 as i8));
Contributor
This is only using SSE? Curious whether there are AVX2 intrinsics to make this even faster.

Contributor Author
I didn't find an AVX2 intrinsic for this, but reduce_min is not used for now.

Contributor
Let's just delete reduce_sum and reduce_min if they are not used.

#[cfg(target_arch = "aarch64")]
unsafe {
Self(vminq_u8(self.0, rhs.0))
}
Contributor
Let's always have a fallback route.

Contributor Author
added

#[case(4, DistanceType::L2, 0.9)]
#[case(4, DistanceType::Cosine, 0.9)]
#[case(4, DistanceType::Dot, 0.8)]
#[case(4, DistanceType::L2, 0.75)]
Contributor
You mentioned the new algorithm can have decent recall? Should we bump this up?

Contributor Author
fixed

let num_vectors = code.len() * 2 / num_sub_vectors;
let mut distances = vec![0.0_f32; num_vectors];
// store the distances in u32 to avoid overflow
Contributor
nit: f32

Contributor Author
fixed

Signed-off-by: BubbleCal <[email protected]>
debug_assert_eq!(dist_table.as_array(), origin_dist_table.as_array());

// compute next distances
let next_indices = vec_indices.right_shift_4();
Contributor
Should we just implement Shr for u8x16? This interface looks weird.

fn shuffle(&self, indices: u8x16) -> Self {
#[cfg(target_arch = "x86_64")]
unsafe {
Self(_mm_shuffle_epi8(self.0, indices.0))
Contributor
Contributor Author
Yeah, I believe so.
I chose u8x16 because it fits in an ARM register.

Contributor
We could implement u8x32 as two 128-bit registers on ARM? In general this can speed up older x86 CPUs a lot, similar to https://github.com/lancedb/lance/blob/main/rust/lance-linalg/src/simd/f32.rs#L462

Contributor Author
Doable, but let's do it in the next PR? It would also require changing the computation logic, because there are only 16 centroids in distance_table for each sub-vector.
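As background on why the shuffle works as a lookup table: _mm_shuffle_epi8 treats one register as a 16-entry byte table and the other as per-lane indices. A scalar model of that behavior (a sketch; it assumes the indices are 4-bit PQ codes in 0..=15, since on x86 a lane is zeroed when the index byte's high bit is set):

```rust
// Scalar model of the byte shuffle used as an in-register lookup table:
// each index lane selects one of the 16 table bytes (indices assumed
// to be masked to the low nibble, matching 4-bit PQ codes).
fn shuffle_lookup(table: [u8; 16], indices: [u8; 16]) -> [u8; 16] {
    indices.map(|i| table[(i & 0x0F) as usize])
}
```

This is the "in-register lookup table" of the PR title: one shuffle performs 16 distance-table lookups at once.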

.into_iter()
.zip(distances.iter_mut())
.for_each(|(d, sum)| {
*sum += d as f32;
Contributor
Is this reduce_sum?

Contributor Author
No, sum is distances[i].
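To make that distinction concrete: the loop adds each lane's u8 partial distance into the corresponding vector's running f32 total, rather than horizontally reducing one register. A scalar sketch (hypothetical helper name):

```rust
// Scalar sketch of the accumulation step: each u8 partial distance is
// added to its vector's f32 accumulator (distances[i]); this is a
// per-vector sum across sub-vectors, not a horizontal reduce.
fn accumulate(partial: &[u8], distances: &mut [f32]) {
    for (d, sum) in partial.iter().zip(distances.iter_mut()) {
        *sum += *d as f32;
    }
}
```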

Signed-off-by: BubbleCal <[email protected]>
@github-actions github-actions bot added the python label Dec 5, 2024
Signed-off-by: BubbleCal <[email protected]>
@BubbleCal BubbleCal requested a review from eddyxu December 5, 2024 07:01
@BubbleCal BubbleCal merged commit 6c7b9fd into lancedb:main Dec 5, 2024
26 checks passed
4 participants