From 211bcf67aadb1175af382f55403ae759177281c7 Mon Sep 17 00:00:00 2001 From: Chinmay Deshpande Date: Fri, 10 Jan 2025 09:05:41 +0530 Subject: [PATCH] [AMDGPU] Implement IR variant of isFMAFasterThanFMulAndFAdd (#121465) --- llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 56 +++++++++ llvm/lib/Target/AMDGPU/SIISelLowering.h | 3 + .../CodeGen/AMDGPU/prevent-fmul-hoist-ir.ll | 108 ++++++++++++------ 3 files changed, 132 insertions(+), 35 deletions(-) diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index 992f7ed99d3b..e057c665e39d 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -5732,6 +5732,35 @@ bool SITargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF, return false; } +// Refer to comments added to the MIR variant of isFMAFasterThanFMulAndFAdd for +// specific details. +bool SITargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F, + Type *Ty) const { + switch (Ty->getScalarSizeInBits()) { + case 16: { + SIModeRegisterDefaults Mode = SIModeRegisterDefaults(F, *Subtarget); + return Subtarget->has16BitInsts() && + Mode.FP64FP16Denormals != DenormalMode::getPreserveSign(); + } + case 32: { + if (!Subtarget->hasMadMacF32Insts()) + return Subtarget->hasFastFMAF32(); + + SIModeRegisterDefaults Mode = SIModeRegisterDefaults(F, *Subtarget); + if (Mode.FP32Denormals != DenormalMode::getPreserveSign()) + return Subtarget->hasFastFMAF32() || Subtarget->hasDLInsts(); + + return Subtarget->hasFastFMAF32() && Subtarget->hasDLInsts(); + } + case 64: + return true; + default: + break; + } + + return false; +} + bool SITargetLowering::isFMADLegal(const MachineInstr &MI, LLT Ty) const { if (!Ty.isScalar()) return false; @@ -16992,6 +17021,33 @@ bool SITargetLowering::checkForPhysRegDependency( return false; } +/// Check if it is profitable to hoist instruction in then/else to if. +bool SITargetLowering::isProfitableToHoist(Instruction *I) const { + if (!I->hasOneUse()) + return true; + + Instruction *User = I->user_back(); + // TODO: Add more patterns that are not profitable to hoist and + // handle modifiers such as fabs and fneg + switch (I->getOpcode()) { + case Instruction::FMul: { + if (User->getOpcode() != Instruction::FSub && + User->getOpcode() != Instruction::FAdd) + return true; + + const TargetOptions &Options = getTargetMachine().Options; + + return ((!I->hasAllowContract() || !User->hasAllowContract()) && + Options.AllowFPOpFusion != FPOpFusion::Fast && + !Options.UnsafeFPMath) || + !isFMAFasterThanFMulAndFAdd(*I->getFunction(), User->getType()); + } + default: + return true; + } + return true; +} + void SITargetLowering::emitExpandAtomicAddrSpacePredicate( Instruction *AI) const { // Given: atomicrmw fadd ptr %addr, float %val ordering diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h index 299c8f5f7392..27960a094092 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.h +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h @@ -459,6 +459,7 @@ class SITargetLowering final : public AMDGPUTargetLowering { EVT VT) const override; bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF, const LLT Ty) const override; + bool isFMAFasterThanFMulAndFAdd(const Function &F, Type *Ty) const override; bool isFMADLegal(const SelectionDAG &DAG, const SDNode *N) const override; bool isFMADLegal(const MachineInstr &MI, const LLT Ty) const override; @@ -538,6 +539,8 @@ class SITargetLowering final : public AMDGPUTargetLowering { const TargetInstrInfo *TII, unsigned &PhysReg, int &Cost) const override; + bool isProfitableToHoist(Instruction *I) const override; + bool isKnownNeverNaNForTargetNode(SDValue Op, const SelectionDAG &DAG, bool SNaN = false, diff --git a/llvm/test/CodeGen/AMDGPU/prevent-fmul-hoist-ir.ll b/llvm/test/CodeGen/AMDGPU/prevent-fmul-hoist-ir.ll index ef3e04c0e996..c68cd8254091 100644 --- a/llvm/test/CodeGen/AMDGPU/prevent-fmul-hoist-ir.ll +++ b/llvm/test/CodeGen/AMDGPU/prevent-fmul-hoist-ir.ll @@ -11,16 +11,17 @@ define double @is_profitable_f64_contract(ptr dereferenceable(8) %ptr_x, ptr der ; GFX-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00 ; GFX-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8 ; GFX-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8 -; GFX-NEXT: [[MUL:%.*]] = fmul contract double [[X]], [[A_1]] ; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] ; GFX: [[COMMON_RET:.*]]: ; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] ; GFX-NEXT: ret double [[COMMON_RET_OP]] ; GFX: [[IF_THEN]]: +; GFX-NEXT: [[MUL:%.*]] = fmul contract double [[X]], [[A_1]] ; GFX-NEXT: [[ADD]] = fadd contract double 1.000000e+00, [[MUL]] ; GFX-NEXT: br label %[[COMMON_RET]] ; GFX: [[IF_ELSE]]: -; GFX-NEXT: [[SUB]] = fsub contract double [[MUL]], [[Y]] +; GFX-NEXT: [[MUL1:%.*]] = fmul contract double [[X]], [[A_1]] +; GFX-NEXT: [[SUB]] = fsub contract double [[MUL1]], [[Y]] ; GFX-NEXT: br label %[[COMMON_RET]] ; entry: @@ -93,16 +94,17 @@ define float @is_profitable_f32(ptr dereferenceable(8) %ptr_x, ptr dereferenceab ; GFX-NEXT: [[CMP:%.*]] = fcmp oeq float [[Y]], 0.000000e+00 ; GFX-NEXT: [[X:%.*]] = load float, ptr [[PTR_X]], align 8 ; GFX-NEXT: [[A_1:%.*]] = load float, ptr [[PTR_A]], align 8 -; GFX-NEXT: [[MUL:%.*]] = fmul contract float [[X]], [[A_1]] ; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] ; GFX: [[COMMON_RET:.*]]: -; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi float [ [[MUL]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] +; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi float [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] ; GFX-NEXT: ret float [[COMMON_RET_OP]] ; GFX: [[IF_THEN]]: +; GFX-NEXT: [[MUL]] = fmul contract float [[X]], [[A_1]] ; GFX-NEXT: [[ADD:%.*]] = fadd contract float 1.000000e+00, [[MUL]] ; GFX-NEXT: br label %[[COMMON_RET]] ; GFX: [[IF_ELSE]]: -; GFX-NEXT: [[SUB]] = fsub contract float [[MUL]], [[Y]] +; GFX-NEXT: [[MUL1:%.*]] = fmul contract float [[X]], [[A_1]] +; GFX-NEXT: [[SUB]] = fsub contract float [[MUL1]], [[Y]] ; GFX-NEXT: br label %[[COMMON_RET]] ; entry: @@ -111,7 +113,6 @@ entry: %x = load float, ptr %ptr_x, align 8 br i1 %cmp, label %if.then, label %if.else - if.then: ; preds = %entry %a_1 = load float, ptr %ptr_a, align 8 %mul = fmul contract float %x, %a_1 @@ -172,16 +173,17 @@ define half @is_profitable_f16_ieee(ptr dereferenceable(8) %ptr_x, ptr dereferen ; GFX-NEXT: [[CMP:%.*]] = fcmp oeq half [[Y]], 0xH0000 ; GFX-NEXT: [[X:%.*]] = load half, ptr [[PTR_X]], align 8 ; GFX-NEXT: [[A_1:%.*]] = load half, ptr [[PTR_A]], align 8 -; GFX-NEXT: [[MUL:%.*]] = fmul contract half [[X]], [[A_1]] ; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] ; GFX: [[COMMON_RET:.*]]: -; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi half [ [[MUL]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] +; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi half [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] ; GFX-NEXT: ret half [[COMMON_RET_OP]] ; GFX: [[IF_THEN]]: +; GFX-NEXT: [[MUL]] = fmul contract half [[X]], [[A_1]] ; GFX-NEXT: [[ADD:%.*]] = fadd contract half [[Y]], [[MUL]] ; GFX-NEXT: br label %[[COMMON_RET]] ; GFX: [[IF_ELSE]]: -; GFX-NEXT: [[SUB]] = fsub contract half [[MUL]], [[Y]] +; GFX-NEXT: [[MUL1:%.*]] = fmul contract half [[X]], [[A_1]] +; GFX-NEXT: [[SUB]] = fsub contract half [[MUL1]], [[Y]] ; GFX-NEXT: br label %[[COMMON_RET]] ; entry: @@ -250,16 +252,17 @@ define bfloat @is_profitable_bfloat_ieee(ptr dereferenceable(8) %ptr_x, ptr dere ; GFX-NEXT: [[CMP:%.*]] = fcmp oeq bfloat [[Y]], 0xR0000 ; GFX-NEXT: [[X:%.*]] = load bfloat, ptr [[PTR_X]], align 8 ; GFX-NEXT: [[A_1:%.*]] = load bfloat, ptr [[PTR_A]], align 8 -; GFX-NEXT: [[MUL:%.*]] = fmul contract bfloat [[X]], [[A_1]] ; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] ; GFX: [[COMMON_RET:.*]]: -; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi bfloat [ [[MUL]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] +; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi bfloat [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] ; GFX-NEXT: ret bfloat [[COMMON_RET_OP]] ; GFX: [[IF_THEN]]: +; GFX-NEXT: [[MUL]] = fmul contract bfloat [[X]], [[A_1]] ; GFX-NEXT: [[ADD:%.*]] = fadd contract bfloat 0xR3F80, [[MUL]] ; GFX-NEXT: br label %[[COMMON_RET]] ; GFX: [[IF_ELSE]]: -; GFX-NEXT: [[SUB]] = fsub contract bfloat [[MUL]], [[Y]] +; GFX-NEXT: [[MUL1:%.*]] = fmul contract bfloat [[X]], [[A_1]] +; GFX-NEXT: [[SUB]] = fsub contract bfloat [[MUL1]], [[Y]] ; GFX-NEXT: br label %[[COMMON_RET]] ; entry: @@ -330,16 +333,17 @@ define <8 x half> @is_profitable_vector(ptr dereferenceable(8) %ptr_x, ptr deref ; GFX-NEXT: [[V1:%.*]] = load <8 x half>, ptr addrspace(3) @v1_ptr, align 16 ; GFX-NEXT: [[V2:%.*]] = load <8 x half>, ptr addrspace(3) @v2_ptr, align 16 ; GFX-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00 -; GFX-NEXT: [[MUL:%.*]] = fmul contract <8 x half> [[V1]], [[X]] ; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] ; GFX: [[COMMON_RET:.*]]: ; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi <8 x half> [ [[ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] ; GFX-NEXT: ret <8 x half> [[COMMON_RET_OP]] ; GFX: [[IF_THEN]]: +; GFX-NEXT: [[MUL:%.*]] = fmul contract <8 x half> [[V1]], [[X]] ; GFX-NEXT: [[ADD]] = fadd contract <8 x half> [[V2]], [[MUL]] ; GFX-NEXT: br label %[[COMMON_RET]] ; GFX: [[IF_ELSE]]: -; GFX-NEXT: [[SUB]] = fsub contract <8 x half> [[MUL]], [[V2]] +; GFX-NEXT: [[MUL1:%.*]] = fmul contract <8 x half> [[V1]], [[X]] +; GFX-NEXT: [[SUB]] = fsub contract <8 x half> [[MUL1]], [[V2]] ; GFX-NEXT: br label %[[COMMON_RET]] ; entry: @@ -362,23 +366,61 @@ if.else: ; preds = %entry } define double @is_profitable_f64_nocontract(ptr dereferenceable(8) %ptr_x, ptr dereferenceable(8) %ptr_y, ptr dereferenceable(8) %ptr_a) #0 { -; GFX-LABEL: define double @is_profitable_f64_nocontract( -; GFX-SAME: ptr dereferenceable(8) [[PTR_X:%.*]], ptr dereferenceable(8) [[PTR_Y:%.*]], ptr dereferenceable(8) [[PTR_A:%.*]]) #[[ATTR0]] { -; GFX-NEXT: [[Y:%.*]] = load double, ptr [[PTR_Y]], align 8 -; GFX-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00 -; GFX-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8 -; GFX-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8 -; GFX-NEXT: [[MUL:%.*]] = fmul double [[X]], [[A_1]] -; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] -; GFX: [[COMMON_RET:.*]]: -; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[PTR_ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] -; GFX-NEXT: ret double [[COMMON_RET_OP]] -; GFX: [[IF_THEN]]: -; GFX-NEXT: [[PTR_ADD]] = fadd double 1.000000e+00, [[MUL]] -; GFX-NEXT: br label %[[COMMON_RET]] -; GFX: [[IF_ELSE]]: -; GFX-NEXT: [[SUB]] = fsub double [[MUL]], [[Y]] -; GFX-NEXT: br label %[[COMMON_RET]] +; FP-CONTRACT-FAST-LABEL: define double @is_profitable_f64_nocontract( +; FP-CONTRACT-FAST-SAME: ptr dereferenceable(8) [[PTR_X:%.*]], ptr dereferenceable(8) [[PTR_Y:%.*]], ptr dereferenceable(8) [[PTR_A:%.*]]) #[[ATTR0]] { +; FP-CONTRACT-FAST-NEXT: [[Y:%.*]] = load double, ptr [[PTR_Y]], align 8 +; FP-CONTRACT-FAST-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00 +; FP-CONTRACT-FAST-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8 +; FP-CONTRACT-FAST-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8 +; FP-CONTRACT-FAST-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] +; FP-CONTRACT-FAST: [[COMMON_RET:.*]]: +; FP-CONTRACT-FAST-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[PTR_ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] +; FP-CONTRACT-FAST-NEXT: ret double [[COMMON_RET_OP]] +; FP-CONTRACT-FAST: [[IF_THEN]]: +; FP-CONTRACT-FAST-NEXT: [[MUL:%.*]] = fmul double [[X]], [[A_1]] +; FP-CONTRACT-FAST-NEXT: [[PTR_ADD]] = fadd double 1.000000e+00, [[MUL]] +; FP-CONTRACT-FAST-NEXT: br label %[[COMMON_RET]] +; FP-CONTRACT-FAST: [[IF_ELSE]]: +; FP-CONTRACT-FAST-NEXT: [[MUL1:%.*]] = fmul double [[X]], [[A_1]] +; FP-CONTRACT-FAST-NEXT: [[SUB]] = fsub double [[MUL1]], [[Y]] +; FP-CONTRACT-FAST-NEXT: br label %[[COMMON_RET]] +; +; UNSAFE-FP-MATH-LABEL: define double @is_profitable_f64_nocontract( +; UNSAFE-FP-MATH-SAME: ptr dereferenceable(8) [[PTR_X:%.*]], ptr dereferenceable(8) [[PTR_Y:%.*]], ptr dereferenceable(8) [[PTR_A:%.*]]) #[[ATTR0]] { +; UNSAFE-FP-MATH-NEXT: [[Y:%.*]] = load double, ptr [[PTR_Y]], align 8 +; UNSAFE-FP-MATH-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00 +; UNSAFE-FP-MATH-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8 +; UNSAFE-FP-MATH-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8 +; UNSAFE-FP-MATH-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] +; UNSAFE-FP-MATH: [[COMMON_RET:.*]]: +; UNSAFE-FP-MATH-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[PTR_ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] +; UNSAFE-FP-MATH-NEXT: ret double [[COMMON_RET_OP]] +; UNSAFE-FP-MATH: [[IF_THEN]]: +; UNSAFE-FP-MATH-NEXT: [[MUL:%.*]] = fmul double [[X]], [[A_1]] +; UNSAFE-FP-MATH-NEXT: [[PTR_ADD]] = fadd double 1.000000e+00, [[MUL]] +; UNSAFE-FP-MATH-NEXT: br label %[[COMMON_RET]] +; UNSAFE-FP-MATH: [[IF_ELSE]]: +; UNSAFE-FP-MATH-NEXT: [[MUL1:%.*]] = fmul double [[X]], [[A_1]] +; UNSAFE-FP-MATH-NEXT: [[SUB]] = fsub double [[MUL1]], [[Y]] +; UNSAFE-FP-MATH-NEXT: br label %[[COMMON_RET]] +; +; NO-UNSAFE-FP-MATH-LABEL: define double @is_profitable_f64_nocontract( +; NO-UNSAFE-FP-MATH-SAME: ptr dereferenceable(8) [[PTR_X:%.*]], ptr dereferenceable(8) [[PTR_Y:%.*]], ptr dereferenceable(8) [[PTR_A:%.*]]) #[[ATTR0]] { +; NO-UNSAFE-FP-MATH-NEXT: [[Y:%.*]] = load double, ptr [[PTR_Y]], align 8 +; NO-UNSAFE-FP-MATH-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00 +; NO-UNSAFE-FP-MATH-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8 +; NO-UNSAFE-FP-MATH-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8 +; NO-UNSAFE-FP-MATH-NEXT: [[MUL:%.*]] = fmul double [[X]], [[A_1]] +; NO-UNSAFE-FP-MATH-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] +; NO-UNSAFE-FP-MATH: [[COMMON_RET:.*]]: +; NO-UNSAFE-FP-MATH-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[PTR_ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ] +; NO-UNSAFE-FP-MATH-NEXT: ret double [[COMMON_RET_OP]] +; NO-UNSAFE-FP-MATH: [[IF_THEN]]: +; NO-UNSAFE-FP-MATH-NEXT: [[PTR_ADD]] = fadd double 1.000000e+00, [[MUL]] +; NO-UNSAFE-FP-MATH-NEXT: br label %[[COMMON_RET]] +; NO-UNSAFE-FP-MATH: [[IF_ELSE]]: +; NO-UNSAFE-FP-MATH-NEXT: [[SUB]] = fsub double [[MUL]], [[Y]] +; NO-UNSAFE-FP-MATH-NEXT: br label %[[COMMON_RET]] ; %y = load double, ptr %ptr_y, align 8 %cmp = fcmp oeq double %y, 0.000000e+00 @@ -400,7 +442,3 @@ if.else: ; preds = %entry attributes #0 = { nounwind "denormal-fp-math"="preserve-sign,preserve-sign" } attributes #1 = { nounwind "denormal-fp-math"="ieee,ieee" } -;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: -; FP-CONTRACT-FAST: {{.*}} -; NO-UNSAFE-FP-MATH: {{.*}} -; UNSAFE-FP-MATH: {{.*}}