perf: in-register lookup table & SIMD for 4bit PQ #3178
Conversation
Signed-off-by: BubbleCal <[email protected]>
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #3178      +/-   ##
==========================================
- Coverage   78.62%   78.51%   -0.12%
==========================================
  Files         243      244       +1
  Lines       82889    83213     +324
  Branches    82889    83213     +324
==========================================
+ Hits        65170    65331     +161
- Misses      14933    15099     +166
+ Partials     2786     2783       -3

Flags with carried forward coverage won't be shown.
Signed-off-by: BubbleCal <[email protected]>
// let qmax = distance_table
//     .chunks(NUM_CENTROIDS)
//     .tuple_windows()
//     .map(|(a, b)| {
delete those?
fixed
    distances
}

// Quantize the distance table to u8
// returns quantized_distance_table
// used for only 4bit PQ so num_centroids must be 16
Can you add a comment about what the returns are?
fixed
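For reference, a minimal sketch of what such a quantization step can look like (the function name and the global min/max scaling here are illustrative, not necessarily the exact PR code): the f32 distance table is mapped affinely into 0..=255 so the later SIMD lookup can operate on u8 values, and the (min, scale) pair is returned so summed u8 distances can be mapped back to approximate f32 distances.

```rust
/// Sketch only: quantize an f32 PQ distance table (num_sub_vectors * 16
/// entries for 4bit PQ) into u8 with a single global affine mapping, so the
/// quantized values stay comparable when summed across sub-vectors.
fn quantize_distance_table(distance_table: &[f32]) -> (Vec<u8>, f32, f32) {
    let min = distance_table.iter().copied().fold(f32::INFINITY, f32::min);
    let max = distance_table
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);
    // Avoid dividing by zero when all distances are equal.
    let scale = if max > min { 255.0 / (max - min) } else { 0.0 };
    let quantized = distance_table
        .iter()
        .map(|&d| ((d - min) * scale).round() as u8)
        .collect();
    (quantized, min, scale)
}
```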
@@ -278,7 +294,7 @@ mod tests {
     let pq_codes = Vec::from_iter((0..num_vectors * num_sub_vectors).map(|v| v as u8));
     let pq_codes = UInt8Array::from_iter_values(pq_codes);
     let transposed_codes = transpose(&pq_codes, num_vectors, num_sub_vectors);
-    let distances = compute_l2_distance(
+    let distances = compute_pq_distance(
We don't use dot anymore?
compute_l2_distance and compute_dot_distance are the same, so keep only one; the difference is in how the distance table is built.
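So once the table is built, the scan is metric-agnostic. A rough scalar sketch of that shared kernel (the name, the row-major table layout, and the unpacked codes are simplifications; the PR works on transposed, nibble-packed codes):

```rust
/// Sketch: distance_table holds num_sub_vectors * num_centroids precomputed
/// centroid distances (L2 or dot makes no difference here); codes holds one
/// centroid id per sub-vector for each vector.
fn compute_pq_distance_scalar(
    distance_table: &[f32],
    codes: &[u8],
    num_sub_vectors: usize,
    num_centroids: usize,
) -> Vec<f32> {
    codes
        .chunks_exact(num_sub_vectors)
        .map(|code| {
            code.iter()
                .enumerate()
                .map(|(sub, &c)| distance_table[sub * num_centroids + c as usize])
                .sum()
        })
        .collect()
}
```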
rust/lance-linalg/src/simd/u8.rs (outdated)
#[derive(Clone, Copy)]
pub struct u8x16(pub __m128i);

/// 16 of 32-bit `f32` values. Use 512-bit SIMD if possible.
This is copied from simd/f32?
fixed
rust/lance-linalg/src/simd/u8.rs (outdated)
}

#[inline]
pub fn right_shift_4(self) -> Self {
Is this API compatible with portable_simd?
Didn't see a bit-shifting operation in portable_simd.
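For reference, a hedged sketch of the usual way to do a fixed >> 4 on packed u8 lanes (this mirrors the idea behind right_shift_4, not its exact code): SSE has no 8-bit shift, so shift the 16-bit lanes and mask off the bits that crossed byte boundaries; a scalar loop stands in for other targets here.

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Shift each u8 lane right by 4, e.g. to extract the high nibble of a
/// packed 4bit PQ code. Sketch only.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn srli4_u8x16(v: __m128i) -> __m128i {
    // No _mm_srli_epi8 exists, so shift 16-bit lanes and mask off the bits
    // shifted in from the neighboring byte.
    _mm_and_si128(_mm_srli_epi16::<4>(v), _mm_set1_epi8(0x0F))
}

#[cfg(not(target_arch = "x86_64"))]
#[inline]
fn srli4_u8x16(mut v: [u8; 16]) -> [u8; 16] {
    for lane in v.iter_mut() {
        *lane >>= 4;
    }
    v
}
```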
rust/lance-linalg/src/simd/u8.rs (outdated)
}
#[cfg(target_arch = "loongarch64")]
unsafe {
    Self(lasx_xvfrsh_b(self.0, 4))
Huh, you figured out how to use loongarch?
lol no, there's no way to test it; let me remove all the loongarch code for u8x16.
rust/lance-linalg/src/simd/u8.rs (outdated)
unsafe {
    Self(vandq_u8(self.0, vdupq_n_u8(mask)))
}
#[cfg(target_arch = "loongarch64")]
Shall we always have a fallback for the non-SIMD route?
added
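For reference, the fallback can be a plain lane-wise loop. A sketch, assuming a hypothetical fallback representation where u8x16 wraps a bare [u8; 16] (the SIMD builds wrap __m128i / uint8x16_t instead; the method name `bitwise_and` is assumed from the snippet above):

```rust
/// Hypothetical fallback-only storage for targets without a SIMD path.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy)]
pub struct u8x16(pub [u8; 16]);

#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
impl u8x16 {
    /// AND every lane with `mask`, matching the semantics of the
    /// _mm_and_si128 / vandq_u8 paths.
    #[inline]
    pub fn bitwise_and(self, mask: u8) -> Self {
        let mut lanes = self.0;
        for lane in lanes.iter_mut() {
            *lane &= mask;
        }
        Self(lanes)
    }
}
```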
fn reduce_min(&self) -> u8 {
    #[cfg(target_arch = "x86_64")]
    unsafe {
        let low = _mm_and_si128(self.0, _mm_set1_epi8(0xFF_u8 as i8));
This is only using SSE? Curious whether there is AVX2-related code that could make this even faster.
Didn't find an AVX2 intrinsic to do this, but reduce_min is not used for now.
Let's just delete reduce_sum and reduce_min if they are not used.
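(For reference only, since these methods are being removed: a horizontal u8 min on x86 is usually a log2 fold with _mm_min_epu8 and byte shifts rather than anything AVX2-specific; a sketch:)

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Sketch: fold the 16 lanes pairwise until lane 0 holds the minimum.
/// _mm_min_epu8 / _mm_srli_si128 are SSE2; _mm_extract_epi8 needs SSE4.1.
#[cfg(target_arch = "x86_64")]
unsafe fn reduce_min_u8x16(v: __m128i) -> u8 {
    let v = _mm_min_epu8(v, _mm_srli_si128::<8>(v));
    let v = _mm_min_epu8(v, _mm_srli_si128::<4>(v));
    let v = _mm_min_epu8(v, _mm_srli_si128::<2>(v));
    let v = _mm_min_epu8(v, _mm_srli_si128::<1>(v));
    _mm_extract_epi8::<0>(v) as u8
}
```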
#[cfg(target_arch = "aarch64")]
unsafe {
    Self(vminq_u8(self.0, rhs.0))
}
Let's always have a fallback route.
added
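The same lane-wise pattern works for the fallback here; a short sketch of an element-wise min, under the same hypothetical [u8; 16] fallback storage as in the earlier sketch:

```rust
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
impl u8x16 {
    /// Lane-wise minimum, matching the vminq_u8 / _mm_min_epu8 paths.
    #[inline]
    pub fn min(self, rhs: Self) -> Self {
        let mut lanes = self.0;
        for (l, r) in lanes.iter_mut().zip(rhs.0.iter()) {
            *l = (*l).min(*r);
        }
        Self(lanes)
    }
}
```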
#[case(4, DistanceType::L2, 0.9)]
#[case(4, DistanceType::Cosine, 0.9)]
#[case(4, DistanceType::Dot, 0.8)]
#[case(4, DistanceType::L2, 0.75)]
You mentioned the new algorithm can have decent recall? Should we bump this up?
fixed
let num_vectors = code.len() * 2 / num_sub_vectors;
let mut distances = vec![0.0_f32; num_vectors];
// store the distances in u32 to avoid overflow
nit: f32
fixed
Signed-off-by: BubbleCal <[email protected]>
debug_assert_eq!(dist_table.as_array(), origin_dist_table.as_array());

// compute next distances
let next_indices = vec_indices.right_shift_4();
Should we just implement a Shr for u8x16? This interface looks weird.
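A sketch of what the suggested operator could look like, shown against a hypothetical portable [u8; 16] wrapper just to illustrate the shape; the real u8x16 wraps a register and would dispatch to intrinsics (as right_shift_4 does today).

```rust
use std::ops::Shr;

/// Hypothetical portable wrapper, for illustration only.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy)]
struct u8x16([u8; 16]);

impl Shr<u32> for u8x16 {
    type Output = Self;

    /// Lane-wise logical right shift by `rhs` bits.
    fn shr(self, rhs: u32) -> Self {
        let mut lanes = self.0;
        for lane in lanes.iter_mut() {
            *lane >>= rhs;
        }
        Self(lanes)
    }
}
```

Call sites would then read `let next_indices = vec_indices >> 4;`.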
fn shuffle(&self, indices: u8x16) -> Self {
    #[cfg(target_arch = "x86_64")]
    unsafe {
        Self(_mm_shuffle_epi8(self.0, indices.0))
Would it be faster if we used https://doc.rust-lang.org/beta/core/arch/x86_64/fn._mm256_shuffle_epi8.html (u8x32)?
Yeah, I believe so. I chose u8x16 because it fits the ARM register size.
We can implement u8x32 as two 128-bit registers on ARM? In general this can speed things up a lot even on older x86 CPUs, similar to https://github.com/lancedb/lance/blob/main/rust/lance-linalg/src/simd/f32.rs#L462
Doable, but let's do this in the next PR? It would also require changing the computation logic, because there are only 16 centroids in distance_table for each sub-vector.
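For that follow-up, note that _mm256_shuffle_epi8 shuffles within each 128-bit lane, so the 16-entry distance table only needs to be broadcast into both halves; a hedged sketch of that widened lookup (function name illustrative, requires AVX2):

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Look up 32 4-bit codes at once against a 16-entry u8 table. Since
/// _mm256_shuffle_epi8 works per 128-bit lane, broadcasting the table into
/// both lanes makes the two halves behave like two independent u8x16
/// lookups. `indices` lanes must be in 0..=15.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lookup_u8x32(table16: __m128i, indices: __m256i) -> __m256i {
    let table = _mm256_broadcastsi128_si256(table16);
    _mm256_shuffle_epi8(table, indices)
}
```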
.into_iter()
.zip(distances.iter_mut())
.for_each(|(d, sum)| {
    *sum += d as f32;
Is this reduce_sum?
No, sum is distances[i].
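That is, the closure adds each looked-up distance onto the running per-vector sum; roughly, with a plain [u8; 16] batch standing in for the SIMD type:

```rust
/// Sketch of the accumulation shape: `batch` holds the 16 looked-up u8
/// distances (one per vector) for the current sub-vector, and `distances`
/// holds the running f32 sum for each of those 16 vectors.
fn accumulate(batch: [u8; 16], distances: &mut [f32; 16]) {
    batch
        .into_iter()
        .zip(distances.iter_mut())
        .for_each(|(d, sum)| {
            *sum += d as f32;
        });
}
```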
Signed-off-by: BubbleCal <[email protected]>
4bit PQ is 3x faster than before:
Posting the 8bit PQ results here for comparison; in short, 4bit PQ is about 2x faster with the same index params: