
Handle mixed-width args to mul-shift-right #6526

Merged · 3 commits · Jan 4, 2022
25 changes: 25 additions & 0 deletions src/CodeGen_X86.cpp
@@ -495,6 +495,31 @@ void CodeGen_X86::visit(const Call *op) {
return;
}

// A 16-bit mul-shift-right with a shift of less than 16 can sometimes be
// rounded up to a full 16 to use pmulh(u)w, by left-shifting one of the
// operands. This is handled here instead of in the lowering of
// mul_shift_right because it's unlikely to be a good idea on platforms
// other than x86, as it adds an extra shift in the fully-lowered case.
if ((op->type.element_of() == UInt(16) ||
op->type.element_of() == Int(16)) &&
op->is_intrinsic(Call::mul_shift_right)) {
internal_assert(op->args.size() == 3);
const uint64_t *shift = as_const_uint(op->args[2]);
if (shift && *shift < 16 && *shift >= 8) {
Type narrow = op->type.with_bits(8);
Expr narrow_a = lossless_cast(narrow, op->args[0]);
Expr narrow_b = narrow_a.defined() ? Expr() : lossless_cast(narrow, op->args[1]);
int shift_left = 16 - (int)(*shift);
if (narrow_a.defined()) {
codegen(mul_shift_right(op->args[0] << shift_left, op->args[1], 16));
return;
} else if (narrow_b.defined()) {
codegen(mul_shift_right(op->args[0], op->args[1] << shift_left, 16));
return;
}
}
}

struct Pattern {
string intrin;
Expr pattern;
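For intuition, here is a minimal sketch of the identity the new CodeGen_X86 path relies on, written with plain C++ integers rather than Halide IR (the values and names are illustrative, not from the PR): when one operand fits in 8 bits, it can be left-shifted by 16 - shift without overflowing 16 bits, and the full ">> 16" form that pmulhuw computes then gives exactly the same result as the original shift.

#include <cassert>
#include <cstdint>

int main() {
    uint16_t a = 200;            // fits in 8 bits, so it has headroom
    uint16_t b = 40000;
    const int shift = 13;        // 8 <= shift < 16, as the code requires
    const int shift_left = 16 - shift;

    // Original form: widen, multiply, shift right by 13.
    uint16_t direct = uint16_t((uint32_t(a) * uint32_t(b)) >> shift);
    // Pre-shifted form: what pmulhuw computes after the operand shift.
    uint16_t pre = uint16_t((uint32_t(uint16_t(a << shift_left)) * uint32_t(b)) >> 16);
    assert(direct == pre);
    return 0;
}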
20 changes: 20 additions & 0 deletions src/FindIntrinsics.cpp
@@ -384,6 +384,7 @@ class FindIntrinsics : public IRMutator {
auto is_x_same_int = op->type.is_int() && is_int(x, bits);
auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
auto is_y_narrow_uint = op->type.is_uint() && is_uint(y, bits / 2);
if (
// Saturating patterns
rewrite(max(min(widening_add(x, y), upper), lower),
@@ -493,6 +494,25 @@
rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int_or_uint && c0 >= bits) ||

// We can also match on smaller shifts if one of the args is
// narrow. We don't do this for signed (yet), because the
// saturation issue is tricky.
Contributor: Do we think we will ever do so, tricky or not? (If so, tracking issue, please)

Member Author: Not sure. It depends on whether it comes up in practice. It's part of the unbounded frontier of ways we could make instruction selection more clever. I don't want to add a tracking issue for each possible new way to use an instruction.

rewrite(shift_right(widening_mul(x, cast(op->type, y)), c0),
mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)),
is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

rewrite(rounding_shift_right(widening_mul(x, cast(op->type, y)), c0),
rounding_mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)),
is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

rewrite(shift_right(widening_mul(cast(op->type, y), x), c0),
mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)),
is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

rewrite(rounding_shift_right(widening_mul(cast(op->type, y), x), c0),
rounding_mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)),
is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

// Halving subtract patterns
rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1),
halving_sub(x, y),
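The c0 >= bits / 2 bound in the new rewrites is what makes them safe: with a full-width x and a half-width y, the widening product occupies at most bits + bits / 2 bits, so shifting right by at least bits / 2 fits the result back into the original type. A small worst-case check with plain integers (16-bit x, 8-bit y; not Halide IR):

#include <cassert>
#include <cstdint>

int main() {
    uint16_t x = 65535;  // largest 16-bit value
    uint8_t y = 255;     // largest narrow (8-bit) value
    const int c0 = 8;    // bits / 2 for a 16-bit result type

    uint32_t wide = uint32_t(x) * uint32_t(y);  // needs at most 24 bits
    uint32_t shifted = wide >> c0;
    assert(shifted <= 65535);  // fits back into 16 bits, no saturation needed
    return 0;
}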
3 changes: 3 additions & 0 deletions test/correctness/intrinsics.cpp
@@ -308,6 +308,9 @@ int main(int argc, char **argv) {
check(i8(i16(i8x) * i16(i8y) >> 8), mul_shift_right(i8x, i8y, 8));
check(u8(u16(u8x) * u16(u8y) >> 8), mul_shift_right(u8x, u8y, 8));

// Multiplication of mixed-width integers
check(u16(u32(u16x) * u32(u8y) >> 8), mul_shift_right(u16x, u16(u8y), 8));

check(i8_sat(rounding_shift_right(i16(i8x) * i16(i8y), 7)), rounding_mul_shift_right(i8x, i8y, 7));
check(i8(min(rounding_shift_right(i16(i8x) * i16(i8y), 7), 127)), rounding_mul_shift_right(i8x, i8y, 7));
check(i8_sat(rounding_shift_right(i16(i8x) * i16(i8y), 8)), rounding_mul_shift_right(i8x, i8y, 8));
7 changes: 7 additions & 0 deletions test/correctness/simd_op_check.cpp
@@ -117,6 +117,13 @@ class SimdOpCheck : public SimdOpCheckTest {

check("pmulhuw", 4 * w, i16_1 / 15);

// Shifts by amounts other than 16 can also use this instruction, by
// preshifting an arg (when there are bits of headroom), or
// postshifting the result.
check("pmulhuw", 4 * w, u16((u32(u16_1) * u32(u8_2)) >> 13));
check("pmulhw", 4 * w, i16((i32(i16_1) * i32(i16_2)) >> 17));
check("pmulhuw", 4 * w, u16((u32(u16_1) * u32(u16_2)) >> 18));

if (w > 1) { // LLVM does a lousy job at the comparisons for 64-bit types
check("pcmp*b", 8 * w, select(u8_1 == u8_2, u8(1), u8(2)));
check("pcmp*b", 8 * w, select(u8_1 > u8_2, u8(1), u8(2)));
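As a rough illustration of the post-shift case exercised by the ">> 17" test above (my reading of the comment, assuming an arithmetic right shift for signed values, as on x86; plain integers, not Halide IR): shifting right by 17 is the instruction's ">> 16" followed by one more ordinary shift, because flooring twice equals flooring once by the combined amount.

#include <cassert>
#include <cstdint>

int main() {
    int16_t a = 12345, b = -23456;
    int32_t wide = int32_t(a) * int32_t(b);

    int16_t direct = int16_t(wide >> 17);
    int16_t mulh = int16_t(wide >> 16);  // what pmulhw produces
    int16_t post = int16_t(mulh >> 1);   // the extra post-shift
    assert(direct == post);
    return 0;
}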