Improve: Explicit type-casting
ashvardanian committed Mar 17, 2024
1 parent bafa69f commit a272dac
Showing 4 changed files with 43 additions and 34 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -132,10 +132,11 @@ function(set_compiler_flags target cpp_standard target_arch)
# MSVC uses numeric values:
# > 4068 for "unknown pragmas".
# > 4146 for "unary minus operator applied to unsigned type, result still unsigned".
# We also specify /utf-8 to properly handle UTF-8 symbols in tests.
target_compile_options(
${target}
PRIVATE
"$<$<CXX_COMPILER_ID:MSVC>:/STOP;/wd4068;/wd4146>" # For MSVC, /WX would have been sufficient
"$<$<CXX_COMPILER_ID:MSVC>:/Bt;/wd4068;/wd4146;/utf-8>" # For MSVC, /WX would have been sufficient
"$<$<CXX_COMPILER_ID:GNU>:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas;-Wno-cast-function-type;-Wno-unused-function>"
"$<$<CXX_COMPILER_ID:Clang>:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas>"
"$<$<CXX_COMPILER_ID:AppleClang>:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas>"
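For context, warning 4146 typically fires on code like the following minimal sketch (illustrative only, not from this repository), even though the unsigned idiom is well-defined:

#include <stdint.h>

// Classic bit trick: isolate the lowest set bit. MSVC can flag the unary minus
// with C4146 ("unary minus operator applied to unsigned type"), hence /wd4146.
uint64_t lowest_set_bit(uint64_t x) { return x & -x; }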
69 changes: 38 additions & 31 deletions include/stringzilla/stringzilla.h
@@ -132,15 +132,18 @@
#if SZ_DYNAMIC_DISPATCH
#if defined(_WIN32) || defined(__CYGWIN__)
#define SZ_DYNAMIC __declspec(dllexport)
#define SZ_EXTERNAL __declspec(dllimport)
#define SZ_PUBLIC inline static
#define SZ_INTERNAL inline static
#else
#define SZ_DYNAMIC __attribute__((visibility("default")))
#define SZ_EXTERNAL extern
#define SZ_PUBLIC __attribute__((unused)) inline static
#define SZ_INTERNAL __attribute__((always_inline)) inline static
#endif // _WIN32 || __CYGWIN__
#else
#define SZ_DYNAMIC inline static
#define SZ_EXTERNAL extern
#define SZ_PUBLIC inline static
#define SZ_INTERNAL inline static
#endif // SZ_DYNAMIC_DISPATCH
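Roughly, these qualifiers are consumed like the sketch below; `sz_example_count` and `_sz_example_helper` are hypothetical names used only for illustration:

#include <stringzilla/stringzilla.h> // brings in the SZ_* qualifiers and sz_* types

// Exported from the shared library (assuming SZ_DYNAMIC_DISPATCH is set):
SZ_DYNAMIC sz_size_t sz_example_count(sz_cptr_t text, sz_size_t length);

// Declared in a consumer that links against the prebuilt library:
SZ_EXTERNAL sz_size_t sz_example_count(sz_cptr_t text, sz_size_t length);

// Translation-unit-local helper, eligible for inlining and dead-code elimination:
SZ_INTERNAL int _sz_example_helper(int x) { return x + 1; }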
@@ -1330,13 +1333,13 @@ SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) {
return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
}
#else
- SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return _tzcnt_u64(x); }
- SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return _lzcnt_u64(x); }
- SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __popcnt64(x); }
- SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __popcnt(x); }
+ SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); }
+ SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); }
+ SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); }
+ SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); }
#endif
- SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return _tzcnt_u32(x); }
- SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return _lzcnt_u32(x); }
+ SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); }
+ SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); }
SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); }
SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); }
#else
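The new `(int)` casts exist because on MSVC these intrinsics return wider unsigned types, and the implicit narrowing to `int` raises conversion warnings once warnings are treated as errors. A minimal standalone sketch of the same pattern outside the library:

#include <immintrin.h> // _tzcnt_u64 needs BMI1: -mbmi with GCC/Clang, included in /arch:AVX2 MSVC builds
#include <stdio.h>

int main() {
    unsigned long long mask = 0x00F0ull;
    int first = (int)_tzcnt_u64(mask); // explicit cast: no silent 64-bit -> int truncation
    printf("first set bit at index %d\n", first); // prints 4
    return 0;
}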
@@ -3330,7 +3333,7 @@ SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value)

// In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks.
else {
- sz_u64_t value64 = (sz_u64_t)(value)*0x0101010101010101ull;
+ sz_u64_t value64 = (sz_u64_t)(value) * 0x0101010101010101ull;
while ((sz_size_t)target & 7ull) *(target++) = value;
while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8;
while (target != end) *(target++) = value;
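The multiplication by 0x0101010101010101 smears one byte across all eight bytes of a 64-bit word, which is what lets the loop above store eight copies of `value` per iteration. A tiny self-contained check of the broadcast:

#include <assert.h>
#include <stdint.h>

int main() {
    uint8_t value = 0xAB;
    uint64_t value64 = (uint64_t)value * 0x0101010101010101ull; // 0xAB -> 0xABABABABABABABAB
    assert(value64 == 0xABABABABABABABABull);
    return 0;
}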
@@ -4077,28 +4080,28 @@ SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) {
// The simplest approach to compute this if we know that `n` is below or equal to 64:
// return (1ull << n) - 1;
// A slightly more complex approach, if we don't know that `n` is under 64:
- return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? n : 64);
+ return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64);
}

SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) {
// The simplest approach to compute this if we know that `n` is below or equal to 32:
// return (1ull << n) - 1;
// A slightly more complex approach, if we don't know that `n` is under 32:
- return _bzhi_u32(0xFFFFFFFF, n < 32 ? n : 32);
+ return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32);
}

SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) {
// The simplest approach to compute this if we know that `n` is below or equal to 16:
// return (1ull << n) - 1;
// A slightly more complex approach, if we don't know that `n` is under 16:
- return _bzhi_u32(0xFFFFFFFF, n < 16 ? n : 16);
+ return _bzhi_u32(0xFFFFFFFF, n < 16 ? (sz_u32_t)n : 16);
}

SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) {
// The simplest approach to compute this if we know that `n` is below or equal to 64:
// return (1ull << n) - 1;
// A slightly more complex approach, if we don't know that `n` is under 64:
- return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
+ return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n);
}
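BZHI is preferred over `(1ull << n) - 1` because shifting a 64-bit value by 64 is undefined behavior in C and C++, while `_bzhi_u64` simply returns the source unchanged once the index reaches 64; the new `(sz_u32_t)` casts match the intrinsic's 32-bit index parameter. A standalone sanity check (requires BMI2):

#include <immintrin.h> // compile with -mbmi2 (GCC/Clang) or /arch:AVX2 (MSVC)
#include <assert.h>
#include <stdint.h>

int main() {
    for (uint32_t n = 0; n != 64; ++n)
        assert(_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n) == (1ull << n) - 1); // both agree below 64
    assert(_bzhi_u64(0xFFFFFFFFFFFFFFFFull, 64) == 0xFFFFFFFFFFFFFFFFull); // well-defined at 64
    return 0;
}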

SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
@@ -4111,7 +4114,7 @@ SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr
b_vec.zmm = _mm512_loadu_epi8(b);
mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
if (mask_not_equal != 0) {
- int first_diff = _tzcnt_u64(mask_not_equal);
+ sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);
char a_char = a[first_diff];
char b_char = b[first_diff];
return _sz_order_scalars(a_char, b_char);
@@ -4130,7 +4133,7 @@ SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr
// been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards.
mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
if (mask_not_equal != 0) {
- int first_diff = _tzcnt_u64(mask_not_equal);
+ sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);
char a_char = a[first_diff];
char b_char = b[first_diff];
return _sz_order_scalars(a_char, b_char);
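The vectorized logic above boils down to a simple idea: find the first byte where the strings differ and order by that byte; `_tzcnt_u64` on the comparison mask finds that index 64 bytes at a time. The scalar equivalent, as an illustrative sketch (the function name is hypothetical):

#include <stddef.h>

// Returns -1, 0, or +1, deciding the order at the first mismatching byte.
static int order_serial_sketch(char const *a, char const *b, size_t length) {
    for (size_t i = 0; i != length; ++i)
        if (a[i] != b[i]) return (unsigned char)a[i] < (unsigned char)b[i] ? -1 : 1;
    return 0;
}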
@@ -4420,8 +4423,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) {
sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1;
for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length;) {
- sz_size_t remaining_length = next_skew_diagonal_length - i - 2;
- sz_size_t register_length = remaining_length < 32 ? remaining_length : 32;
+ sz_u32_t remaining_length = (sz_u32_t)(next_skew_diagonal_length - i - 2);
+ sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32;
sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length);
longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + i);
// Our original code addressed the shorter string `[next_skew_diagonal_index - i - 2]` for growing `i`.
@@ -4448,7 +4451,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
i += register_length;
}
// Don't forget to populate the first row and the first column of the Levenshtein matrix.
- next_distances[0] = next_distances[next_skew_diagonal_length - 1] = next_skew_diagonal_index;
+ next_distances[0] = next_distances[next_skew_diagonal_length - 1] = (sz_u16_t)next_skew_diagonal_index;
// Perform a circular rotation of those buffers, to reuse the memory.
sz_u16_t *temporary = previous_distances;
previous_distances = current_distances;
@@ -4463,8 +4466,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) {
sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index;
for (sz_size_t i = 0; i != next_skew_diagonal_length;) {
- sz_size_t remaining_length = next_skew_diagonal_length - i;
- sz_size_t register_length = remaining_length < 32 ? remaining_length : 32;
+ sz_u32_t remaining_length = (sz_u32_t)(next_skew_diagonal_length - i);
+ sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32;
sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length);
longer_vec.ymms[0] =
_mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_skew_diagonal_index - n + i);
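Clamping `register_length` to 32 and building a `_bzhi_u32` mask is the tail-handling idiom used throughout: the masked load touches only the first `register_length` bytes and zero-fills the rest of the YMM register, so short diagonals never read past the buffer. A self-contained sketch (assumes AVX-512 BW and VL):

#include <immintrin.h> // compile with -mavx512bw -mavx512vl (GCC/Clang) or /arch:AVX512 (MSVC)

// Loads up to 32 bytes without reading past `src + n`; missing lanes are zeroed.
static __m256i load_tail_upto32(char const *src, unsigned n) {
    unsigned register_length = n < 32 ? n : 32;
    __mmask32 mask = _bzhi_u32(0xFFFFFFFFu, register_length); // low `register_length` bits set
    return _mm256_maskz_loadu_epi8(mask, src);
}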
@@ -4824,7 +4827,7 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //

// Initialize the first row of the Levenshtein matrix with `iota`.
for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer)
- previous_distances[idx_longer] = (sz_ssize_t)idx_longer * gap;
+ previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap;

/// Contains up to 16 consecutive characters from the longer string.
sz_u512_vec_t longer_vec;
@@ -4844,7 +4847,7 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //

sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter;
for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
- sz_i32_t last_in_row = current_distances[0] = (sz_ssize_t)(idx_shorter + 1) * gap;
+ sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap;

// Load one row of the substitution matrix into four ZMM registers.
sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u;
@@ -4904,10 +4907,10 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //
// To minimize the number of loads and stores, we can combine our substitution costs with the previous
// distances, containing the deletion costs.
{
- cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + idx_longer);
+ cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer);
cost_substitution_vec.zmm = _mm512_add_epi32(
cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0)));
- cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + 1 + idx_longer);
+ cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer);
cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);

@@ -4930,29 +4933,31 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //
// ... yet this approach is also quite expensive.
for (int i = 0; i != 16; ++i)
current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, mask, current_vec.zmm);
+ _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm);
}
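The short loop right above is the serial part of the kernel: a running prefix-max in which each lane depends on its predecessor, so it stays scalar while everything around it is vectorized. In isolation the recurrence is just (a sketch over a plain array):

#include <stdint.h>

// scores[i] = max(scores[i], scores[i - 1] + gap), carrying `last` across calls.
static int32_t fold_insertion_costs(int32_t *scores, int count, int32_t last, int32_t gap) {
    for (int i = 0; i != count; ++i) {
        int32_t with_insertion = last + gap;
        scores[i] = last = scores[i] > with_insertion ? scores[i] : with_insertion;
    }
    return last;
}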

// Export the values from 16 to 31.
if (register_length > 16) {
mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + idx_longer + 16);
+ cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16);
cost_substitution_vec.zmm = _mm512_add_epi32(
cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1)));
- cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + 1 + idx_longer + 16);
+ cost_deletion_vec.zmm =
+ _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16);
cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);

// Aggregate running insertion costs within the register.
for (int i = 0; i != 16; ++i)
current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, mask, current_vec.zmm);
+ _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm);
}

// Export the values from 32 to 47.
if (register_length > 32) {
mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + idx_longer + 32);
+ cost_substitution_vec.zmm =
+ _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32);
cost_substitution_vec.zmm = _mm512_add_epi32(
cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0)));
cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + 1 + idx_longer + 32);
Expand All @@ -4962,23 +4967,25 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //
// Aggregate running insertion costs within the register.
for (int i = 0; i != 16; ++i)
current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, mask, current_vec.zmm);
+ _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm);
}

// Export the values from 48 to 63.
if (register_length > 48) {
mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + idx_longer + 48);
+ cost_substitution_vec.zmm =
+ _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48);
cost_substitution_vec.zmm = _mm512_add_epi32(
cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1)));
- cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + 1 + idx_longer + 48);
+ cost_deletion_vec.zmm =
+ _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48);
cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);

// Aggregate running insertion costs within the register.
for (int i = 0; i != 16; ++i)
current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, mask, current_vec.zmm);
+ _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm);
}
}
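A recurring cast in this function: `mask` is a 64-lane byte mask (`__mmask64`), but `_mm512_maskz_loadu_epi32` and `_mm512_mask_storeu_epi32` operate on sixteen 32-bit lanes and expect `__mmask16`, hence the `(__mmask16)` narrowing after each `_kshiftri_mask64`. A minimal sketch of that slicing (assumes AVX-512 BW; the function name is illustrative):

#include <immintrin.h>

// Drop the low 16 bits of a 64-lane byte mask and reuse the next 16 as a dword mask.
static __mmask16 next_dword_mask(__mmask64 byte_mask) {
    return (__mmask16)_kshiftri_mask64(byte_mask, 16);
}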

2 changes: 1 addition & 1 deletion scripts/bench_sort.cpp
@@ -135,7 +135,7 @@ void bench_permute(char const *name, strings_t &strings, permute_t &permute, alg

// Measure elapsed time
stdcc::time_point t2 = stdcc::now();
- double dif = stdc::duration_cast<stdc::nanoseconds>(t2 - t1).count();
+ double dif = stdc::duration_cast<stdc::nanoseconds>(t2 - t1).count() * 1.0;
double millisecs = dif / (iterations * 1e6);
std::printf("Elapsed time is %.2lf milliseconds/iteration for %s.\n", millisecs, name);
}
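The `* 1.0` promotes the integral tick count to `double` before the division; an equivalent and arguably more explicit spelling, in the spirit of this commit, would be a `static_cast` (a sketch, not what the file uses):

#include <chrono>

double elapsed_nanoseconds(std::chrono::high_resolution_clock::time_point t1,
                           std::chrono::high_resolution_clock::time_point t2) {
    auto ticks = std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1).count();
    return static_cast<double>(ticks); // explicit integer-to-double conversion
}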
3 changes: 2 additions & 1 deletion scripts/test.cpp
@@ -111,7 +111,8 @@ static void test_arithmetical_utilities() {

for (sz_u16_t number = 0; number != 256; ++number)
for (sz_u16_t divisor = 2; divisor != 256; ++divisor)
- assert(sz_u8_divide(number, divisor) == (number / divisor));
+ assert(sz_u8_divide(static_cast<sz_u8_t>(number), static_cast<sz_u8_t>(divisor)) ==
+        (static_cast<sz_u8_t>(number) / static_cast<sz_u8_t>(divisor)));
}
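The loop counters are deliberately `sz_u16_t`: an 8-bit counter can never hold 256, so an all-`sz_u8_t` loop would wrap from 255 back to 0 and never terminate; the `static_cast`s then narrow each value to the 8-bit domain that `sz_u8_divide` actually accepts. A sketch of the pitfall and the fix:

#include <stdint.h>

// Anti-pattern (never terminates): `number != 256` is always true for a uint8_t.
// for (uint8_t number = 0; number != 256; ++number) { ... }

// Correct: count in a wider type, narrow at the point of use.
static void visit_all_bytes(void (*visit)(uint8_t)) {
    for (uint16_t number = 0; number != 256; ++number) visit((uint8_t)number);
}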

/**
