diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index df8065212193..bcf401bb07be 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -495,6 +495,31 @@ void CodeGen_X86::visit(const Call *op) { return; } + // A 16-bit mul-shift-right of less than 16 can sometimes be rounded up to a + // full 16 to use pmulh(u)w by left-shifting one of the operands. This is + // handled here instead of in the lowering of mul_shift_right because it's + // unlikely to be a good idea on platforms other than x86, as it adds an + // extra shift in the fully-lowered case. + if ((op->type.element_of() == UInt(16) || + op->type.element_of() == Int(16)) && + op->is_intrinsic(Call::mul_shift_right)) { + internal_assert(op->args.size() == 3); + const uint64_t *shift = as_const_uint(op->args[2]); + if (shift && *shift < 16 && *shift >= 8) { + Type narrow = op->type.with_bits(8); + Expr narrow_a = lossless_cast(narrow, op->args[0]); + Expr narrow_b = narrow_a.defined() ? Expr() : lossless_cast(narrow, op->args[1]); + int shift_left = 16 - (int)(*shift); + if (narrow_a.defined()) { + codegen(mul_shift_right(op->args[0] << shift_left, op->args[1], 16)); + return; + } else if (narrow_b.defined()) { + codegen(mul_shift_right(op->args[0], op->args[1] << shift_left, 16)); + return; + } + } + } + struct Pattern { string intrin; Expr pattern; diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 1b91da64f587..16cd2432a98c 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -384,6 +384,7 @@ class FindIntrinsics : public IRMutator { auto is_x_same_int = op->type.is_int() && is_int(x, bits); auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits); auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint; + auto is_y_narrow_uint = op->type.is_uint() && is_uint(y, bits / 2); if ( // Saturating patterns rewrite(max(min(widening_add(x, y), upper), lower), @@ -493,6 +494,25 @@ class FindIntrinsics : public IRMutator { rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int_or_uint && c0 >= bits) || + // We can also match on smaller shifts if one of the args is + // narrow. We don't do this for signed (yet), because the + // saturation issue is tricky. + rewrite(shift_right(widening_mul(x, cast(op->type, y)), c0), + mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)), + is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || + + rewrite(rounding_shift_right(widening_mul(x, cast(op->type, y)), c0), + rounding_mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)), + is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || + + rewrite(shift_right(widening_mul(cast(op->type, y), x), c0), + mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)), + is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || + + rewrite(rounding_shift_right(widening_mul(cast(op->type, y), x), c0), + rounding_mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)), + is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) || + // Halving subtract patterns rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1), halving_sub(x, y), diff --git a/test/correctness/intrinsics.cpp b/test/correctness/intrinsics.cpp index bc5764f5d7f8..f978a7956a14 100644 --- a/test/correctness/intrinsics.cpp +++ b/test/correctness/intrinsics.cpp @@ -308,6 +308,9 @@ int main(int argc, char **argv) { check(i8(i16(i8x) * i16(i8y) >> 8), mul_shift_right(i8x, i8y, 8)); check(u8(u16(u8x) * u16(u8y) >> 8), mul_shift_right(u8x, u8y, 8)); + // Multiplication of mixed-width integers + check(u16(u32(u16x) * u32(u8y) >> 8), mul_shift_right(u16x, u16(u8y), 8)); + check(i8_sat(rounding_shift_right(i16(i8x) * i16(i8y), 7)), rounding_mul_shift_right(i8x, i8y, 7)); check(i8(min(rounding_shift_right(i16(i8x) * i16(i8y), 7), 127)), rounding_mul_shift_right(i8x, i8y, 7)); check(i8_sat(rounding_shift_right(i16(i8x) * i16(i8y), 8)), rounding_mul_shift_right(i8x, i8y, 8)); diff --git a/test/correctness/simd_op_check.cpp b/test/correctness/simd_op_check.cpp index f82df01147ac..fa51ac99d917 100644 --- a/test/correctness/simd_op_check.cpp +++ b/test/correctness/simd_op_check.cpp @@ -117,6 +117,13 @@ class SimdOpCheck : public SimdOpCheckTest { check("pmulhuw", 4 * w, i16_1 / 15); + // Shifts by amounts other than 16 can also use this instruction, by + // preshifting an arg (when there are bits of headroom), or + // postshifting the result. + check("pmulhuw", 4 * w, u16((u32(u16_1) * u32(u8_2)) >> 13)); + check("pmulhw", 4 * w, i16((i32(i16_1) * i32(i16_2)) >> 17)); + check("pmulhuw", 4 * w, u16((u32(u16_1) * u32(u16_2)) >> 18)); + if (w > 1) { // LLVM does a lousy job at the comparisons for 64-bit types check("pcmp*b", 8 * w, select(u8_1 == u8_2, u8(1), u8(2))); check("pcmp*b", 8 * w, select(u8_1 > u8_2, u8(1), u8(2)));