Skip to content

Commit

Permalink
[InstCombine][X86] Add constant folding for PMULH/PMULHU/PMULHRS intrinsics
Browse files Browse the repository at this point in the history
  • Loading branch information
RKSimon authored and lravenclaw committed Jul 3, 2024
1 parent ae50b80 commit b858a73
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 23 deletions.
47 changes: 42 additions & 5 deletions llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,15 @@ static Value *simplifyX86pack(IntrinsicInst &II,
}

static Value *simplifyX86pmulh(IntrinsicInst &II,
InstCombiner::BuilderTy &Builder) {
InstCombiner::BuilderTy &Builder, bool IsSigned,
bool IsRounding) {
Value *Arg0 = II.getArgOperand(0);
Value *Arg1 = II.getArgOperand(1);
auto *ResTy = cast<FixedVectorType>(II.getType());
[[maybe_unused]] auto *ArgTy = cast<FixedVectorType>(Arg0->getType());
auto *ArgTy = cast<FixedVectorType>(Arg0->getType());
assert(ArgTy == ResTy && ResTy->getScalarSizeInBits() == 16 &&
"Unexpected PMULH types");
assert((!IsRounding || IsSigned) && "PMULHRS instruction must be signed");

// Multiply by undef -> zero (NOT undef!) as other arg could still be zero.
if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1))
Expand All @@ -519,8 +521,33 @@ static Value *simplifyX86pmulh(IntrinsicInst &II,
if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1))
return ConstantAggregateZero::get(ResTy);

// TODO: Constant folding.
return nullptr;
// Constant folding.
if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
return nullptr;

// Extend to twice the width and multiply.
auto Cast =
IsSigned ? Instruction::CastOps::SExt : Instruction::CastOps::ZExt;
auto *ExtTy = FixedVectorType::getExtendedElementVectorType(ArgTy);
Value *LHS = Builder.CreateCast(Cast, Arg0, ExtTy);
Value *RHS = Builder.CreateCast(Cast, Arg1, ExtTy);
Value *Mul = Builder.CreateMul(LHS, RHS);

if (IsRounding) {
// PMULHRSW: truncate to vXi18 of the most significant bits, add one and
// extract bits[16:1].
auto *RndEltTy = IntegerType::get(ExtTy->getContext(), 18);
auto *RndTy = FixedVectorType::get(RndEltTy, ExtTy);
Mul = Builder.CreateLShr(Mul, 14);
Mul = Builder.CreateTrunc(Mul, RndTy);
Mul = Builder.CreateAdd(Mul, ConstantInt::get(RndTy, 1));
Mul = Builder.CreateLShr(Mul, 1);
} else {
// PMULH/PMULHU: extract the vXi16 most significant bits.
Mul = Builder.CreateLShr(Mul, 16);
}

return Builder.CreateTrunc(Mul, ResTy);
}

static Value *simplifyX86pmadd(IntrinsicInst &II,
Expand Down Expand Up @@ -2592,13 +2619,23 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
case Intrinsic::x86_sse2_pmulh_w:
case Intrinsic::x86_avx2_pmulh_w:
case Intrinsic::x86_avx512_pmulh_w_512:
if (Value *V = simplifyX86pmulh(II, IC.Builder, true, false)) {
return IC.replaceInstUsesWith(II, V);
}
break;

case Intrinsic::x86_sse2_pmulhu_w:
case Intrinsic::x86_avx2_pmulhu_w:
case Intrinsic::x86_avx512_pmulhu_w_512:
if (Value *V = simplifyX86pmulh(II, IC.Builder, false, false)) {
return IC.replaceInstUsesWith(II, V);
}
break;

case Intrinsic::x86_ssse3_pmul_hr_sw_128:
case Intrinsic::x86_avx2_pmul_hr_sw:
case Intrinsic::x86_avx512_pmul_hr_sw_512:
if (Value *V = simplifyX86pmulh(II, IC.Builder)) {
if (Value *V = simplifyX86pmulh(II, IC.Builder, true, true)) {
return IC.replaceInstUsesWith(II, V);
}
break;
Expand Down
9 changes: 3 additions & 6 deletions llvm/test/Transforms/InstCombine/X86/x86-pmulh.ll
Original file line number Diff line number Diff line change
Expand Up @@ -111,26 +111,23 @@ define <32 x i16> @zero_pmulh_512_commute(<32 x i16> %a0) {

define <8 x i16> @fold_pmulh_128() {
; CHECK-LABEL: @fold_pmulh_128(
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.sse2.pmulh.w(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
; CHECK-NEXT: ret <8 x i16> [[TMP1]]
; CHECK-NEXT: ret <8 x i16> <i16 0, i16 0, i16 -2, i16 -2, i16 0, i16 -1, i16 -4, i16 -4>
;
%1 = call <8 x i16> @llvm.x86.sse2.pmulh.w(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
ret <8 x i16> %1
}

define <16 x i16> @fold_pmulh_256() {
; CHECK-LABEL: @fold_pmulh_256(
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmulh.w(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
; CHECK-NEXT: ret <16 x i16> [[TMP1]]
; CHECK-NEXT: ret <16 x i16> <i16 0, i16 -1, i16 -1, i16 1, i16 0, i16 0, i16 -3, i16 3, i16 -1, i16 -1, i16 4, i16 5, i16 -1, i16 -1, i16 6, i16 -8>
;
%1 = call <16 x i16> @llvm.x86.avx2.pmulh.w(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
ret <16 x i16> %1
}

define <32 x i16> @fold_pmulh_512() {
; CHECK-LABEL: @fold_pmulh_512(
; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmulh.w.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
; CHECK-NEXT: ret <32 x i16> [[TMP1]]
; CHECK-NEXT: ret <32 x i16> <i16 0, i16 -1, i16 -1, i16 1, i16 0, i16 0, i16 -3, i16 3, i16 -1, i16 -1, i16 4, i16 5, i16 -1, i16 -1, i16 6, i16 -8, i16 0, i16 -1, i16 -1, i16 1, i16 0, i16 0, i16 -3, i16 3, i16 -1, i16 -1, i16 4, i16 5, i16 -1, i16 -1, i16 6, i16 -8>
;
%1 = call <32 x i16> @llvm.x86.avx512.pmulh.w.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
ret <32 x i16> %1
Expand Down
9 changes: 3 additions & 6 deletions llvm/test/Transforms/InstCombine/X86/x86-pmulhrs.ll
Original file line number Diff line number Diff line change
Expand Up @@ -111,26 +111,23 @@ define <32 x i16> @zero_pmulh_512_commute(<32 x i16> %a0) {

define <8 x i16> @fold_pmulh_128() {
; CHECK-LABEL: @fold_pmulh_128(
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
; CHECK-NEXT: ret <8 x i16> [[TMP1]]
; CHECK-NEXT: ret <8 x i16> <i16 0, i16 0, i16 -3, i16 -4, i16 0, i16 0, i16 -7, i16 -8>
;
%1 = call <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
ret <8 x i16> %1
}

define <16 x i16> @fold_pmulh_256() {
; CHECK-LABEL: @fold_pmulh_256(
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
; CHECK-NEXT: ret <16 x i16> [[TMP1]]
; CHECK-NEXT: ret <16 x i16> <i16 0, i16 0, i16 -2, i16 3, i16 0, i16 0, i16 -6, i16 7, i16 0, i16 0, i16 10, i16 11, i16 0, i16 0, i16 14, i16 -15>
;
%1 = call <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
ret <16 x i16> %1
}

define <32 x i16> @fold_pmulh_512() {
; CHECK-LABEL: @fold_pmulh_512(
; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmul.hr.sw.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
; CHECK-NEXT: ret <32 x i16> [[TMP1]]
; CHECK-NEXT: ret <32 x i16> <i16 0, i16 0, i16 -2, i16 3, i16 0, i16 0, i16 -6, i16 7, i16 0, i16 0, i16 10, i16 11, i16 0, i16 0, i16 14, i16 -15, i16 0, i16 0, i16 -2, i16 3, i16 0, i16 0, i16 -6, i16 7, i16 0, i16 0, i16 10, i16 11, i16 0, i16 0, i16 14, i16 -15>
;
%1 = call <32 x i16> @llvm.x86.avx512.pmul.hr.sw.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
ret <32 x i16> %1
Expand Down
9 changes: 3 additions & 6 deletions llvm/test/Transforms/InstCombine/X86/x86-pmulhu.ll
Original file line number Diff line number Diff line change
Expand Up @@ -111,26 +111,23 @@ define <32 x i16> @zero_pmulhu_512_commute(<32 x i16> %a0) {

define <8 x i16> @fold_pmulhu_128() {
; CHECK-LABEL: @fold_pmulhu_128(
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
; CHECK-NEXT: ret <8 x i16> [[TMP1]]
; CHECK-NEXT: ret <8 x i16> <i16 -6, i16 0, i16 1, i16 32763, i16 -14, i16 5, i16 3, i16 32757>
;
%1 = call <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
ret <8 x i16> %1
}

define <16 x i16> @fold_pmulhu_256() {
; CHECK-LABEL: @fold_pmulhu_256(
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmulhu.w(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
; CHECK-NEXT: ret <16 x i16> [[TMP1]]
; CHECK-NEXT: ret <16 x i16> <i16 0, i16 6, i16 1, i16 1, i16 -13, i16 -16, i16 3, i16 3, i16 12, i16 8, i16 -32766, i16 5, i16 16, i16 12, i16 -32764, i16 32748>
;
%1 = call <16 x i16> @llvm.x86.avx2.pmulhu.w(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
ret <16 x i16> %1
}

define <32 x i16> @fold_pmulhu_512() {
; CHECK-LABEL: @fold_pmulhu_512(
; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmulhu.w.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
; CHECK-NEXT: ret <32 x i16> [[TMP1]]
; CHECK-NEXT: ret <32 x i16> <i16 0, i16 6, i16 1, i16 1, i16 -13, i16 -16, i16 3, i16 3, i16 12, i16 8, i16 -32766, i16 5, i16 16, i16 12, i16 -32764, i16 32748, i16 0, i16 6, i16 1, i16 1, i16 -13, i16 -16, i16 3, i16 3, i16 12, i16 8, i16 -32766, i16 5, i16 16, i16 12, i16 -32764, i16 32748>
;
%1 = call <32 x i16> @llvm.x86.avx512.pmulhu.w.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
ret <32 x i16> %1
Expand Down

0 comments on commit b858a73

Please sign in to comment.