Skip to content

Commit

Permalink
[AMDGPU] Implement IR variant of isFMAFasterThanFMulAndFAdd (#121465)
Browse files Browse the repository at this point in the history
  • Loading branch information
chinmaydd authored Jan 10, 2025
1 parent 2ea34cd commit 211bcf6
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 35 deletions.
56 changes: 56 additions & 0 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5732,6 +5732,35 @@ bool SITargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
return false;
}

// Refer to comments added to the MIR variant of isFMAFasterThanFMulAndFAdd for
// specific details.
bool SITargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F,
Type *Ty) const {
switch (Ty->getScalarSizeInBits()) {
case 16: {
SIModeRegisterDefaults Mode = SIModeRegisterDefaults(F, *Subtarget);
return Subtarget->has16BitInsts() &&
Mode.FP64FP16Denormals != DenormalMode::getPreserveSign();
}
case 32: {
if (!Subtarget->hasMadMacF32Insts())
return Subtarget->hasFastFMAF32();

SIModeRegisterDefaults Mode = SIModeRegisterDefaults(F, *Subtarget);
if (Mode.FP32Denormals != DenormalMode::getPreserveSign())
return Subtarget->hasFastFMAF32() || Subtarget->hasDLInsts();

return Subtarget->hasFastFMAF32() && Subtarget->hasDLInsts();
}
case 64:
return true;
default:
break;
}

return false;
}

bool SITargetLowering::isFMADLegal(const MachineInstr &MI, LLT Ty) const {
if (!Ty.isScalar())
return false;
Expand Down Expand Up @@ -16992,6 +17021,33 @@ bool SITargetLowering::checkForPhysRegDependency(
return false;
}

/// Check if it is profitable to hoist instruction in then/else to if.
///
/// Hoisting is reported as profitable for everything except an FMul whose
/// single user is an FAdd/FSub when contraction is permitted (via the
/// `contract` flag on both instructions, fast FP-op fusion, or unsafe FP math)
/// and the target reports FMA as faster than a separate FMul and FAdd for the
/// user's type.
bool SITargetLowering::isProfitableToHoist(Instruction *I) const {
  if (!I->hasOneUse())
    return true;

  Instruction *SoleUser = I->user_back();

  // TODO: Add more patterns that are not profitable to hoist and
  // handle modifiers such as fabs and fneg
  if (I->getOpcode() != Instruction::FMul)
    return true;

  const unsigned UserOpcode = SoleUser->getOpcode();
  const bool UserIsAddOrSub =
      UserOpcode == Instruction::FSub || UserOpcode == Instruction::FAdd;
  if (!UserIsAddOrSub)
    return true;

  const TargetOptions &TMOptions = getTargetMachine().Options;

  // Contraction of the fmul with its user is permitted when both carry the
  // contract flag, or when it is enabled globally through the target options.
  const bool ContractionPermitted =
      (I->hasAllowContract() && SoleUser->hasAllowContract()) ||
      TMOptions.AllowFPOpFusion == FPOpFusion::Fast || TMOptions.UnsafeFPMath;
  if (!ContractionPermitted)
    return true;

  // Contraction is permitted, so hoisting is only profitable if FMA is not
  // the faster choice for this type.
  return !isFMAFasterThanFMulAndFAdd(*I->getFunction(), SoleUser->getType());
}

void SITargetLowering::emitExpandAtomicAddrSpacePredicate(
Instruction *AI) const {
// Given: atomicrmw fadd ptr %addr, float %val ordering
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AMDGPU/SIISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
EVT VT) const override;
bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
const LLT Ty) const override;
// IR-level overload: takes the IR Function and an IR Type rather than a
// MachineFunction and LLT.
bool isFMAFasterThanFMulAndFAdd(const Function &F, Type *Ty) const override;
bool isFMADLegal(const SelectionDAG &DAG, const SDNode *N) const override;
bool isFMADLegal(const MachineInstr &MI, const LLT Ty) const override;

Expand Down Expand Up @@ -538,6 +539,8 @@ class SITargetLowering final : public AMDGPUTargetLowering {
const TargetInstrInfo *TII, unsigned &PhysReg,
int &Cost) const override;

/// Check if it is profitable to hoist instruction \p I in then/else to if.
bool isProfitableToHoist(Instruction *I) const override;

bool isKnownNeverNaNForTargetNode(SDValue Op,
const SelectionDAG &DAG,
bool SNaN = false,
Expand Down
108 changes: 73 additions & 35 deletions llvm/test/CodeGen/AMDGPU/prevent-fmul-hoist-ir.ll
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@ define double @is_profitable_f64_contract(ptr dereferenceable(8) %ptr_x, ptr der
; GFX-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00
; GFX-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8
; GFX-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8
; GFX-NEXT: [[MUL:%.*]] = fmul contract double [[X]], [[A_1]]
; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; GFX: [[COMMON_RET:.*]]:
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: ret double [[COMMON_RET_OP]]
; GFX: [[IF_THEN]]:
; GFX-NEXT: [[MUL:%.*]] = fmul contract double [[X]], [[A_1]]
; GFX-NEXT: [[ADD]] = fadd contract double 1.000000e+00, [[MUL]]
; GFX-NEXT: br label %[[COMMON_RET]]
; GFX: [[IF_ELSE]]:
; GFX-NEXT: [[SUB]] = fsub contract double [[MUL]], [[Y]]
; GFX-NEXT: [[MUL1:%.*]] = fmul contract double [[X]], [[A_1]]
; GFX-NEXT: [[SUB]] = fsub contract double [[MUL1]], [[Y]]
; GFX-NEXT: br label %[[COMMON_RET]]
;
entry:
Expand Down Expand Up @@ -93,16 +94,17 @@ define float @is_profitable_f32(ptr dereferenceable(8) %ptr_x, ptr dereferenceab
; GFX-NEXT: [[CMP:%.*]] = fcmp oeq float [[Y]], 0.000000e+00
; GFX-NEXT: [[X:%.*]] = load float, ptr [[PTR_X]], align 8
; GFX-NEXT: [[A_1:%.*]] = load float, ptr [[PTR_A]], align 8
; GFX-NEXT: [[MUL:%.*]] = fmul contract float [[X]], [[A_1]]
; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; GFX: [[COMMON_RET:.*]]:
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi float [ [[MUL]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi float [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: ret float [[COMMON_RET_OP]]
; GFX: [[IF_THEN]]:
; GFX-NEXT: [[MUL]] = fmul contract float [[X]], [[A_1]]
; GFX-NEXT: [[ADD:%.*]] = fadd contract float 1.000000e+00, [[MUL]]
; GFX-NEXT: br label %[[COMMON_RET]]
; GFX: [[IF_ELSE]]:
; GFX-NEXT: [[SUB]] = fsub contract float [[MUL]], [[Y]]
; GFX-NEXT: [[MUL1:%.*]] = fmul contract float [[X]], [[A_1]]
; GFX-NEXT: [[SUB]] = fsub contract float [[MUL1]], [[Y]]
; GFX-NEXT: br label %[[COMMON_RET]]
;
entry:
Expand All @@ -111,7 +113,6 @@ entry:
%x = load float, ptr %ptr_x, align 8
br i1 %cmp, label %if.then, label %if.else


if.then: ; preds = %entry
%a_1 = load float, ptr %ptr_a, align 8
%mul = fmul contract float %x, %a_1
Expand Down Expand Up @@ -172,16 +173,17 @@ define half @is_profitable_f16_ieee(ptr dereferenceable(8) %ptr_x, ptr dereferen
; GFX-NEXT: [[CMP:%.*]] = fcmp oeq half [[Y]], 0xH0000
; GFX-NEXT: [[X:%.*]] = load half, ptr [[PTR_X]], align 8
; GFX-NEXT: [[A_1:%.*]] = load half, ptr [[PTR_A]], align 8
; GFX-NEXT: [[MUL:%.*]] = fmul contract half [[X]], [[A_1]]
; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; GFX: [[COMMON_RET:.*]]:
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi half [ [[MUL]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi half [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: ret half [[COMMON_RET_OP]]
; GFX: [[IF_THEN]]:
; GFX-NEXT: [[MUL]] = fmul contract half [[X]], [[A_1]]
; GFX-NEXT: [[ADD:%.*]] = fadd contract half [[Y]], [[MUL]]
; GFX-NEXT: br label %[[COMMON_RET]]
; GFX: [[IF_ELSE]]:
; GFX-NEXT: [[SUB]] = fsub contract half [[MUL]], [[Y]]
; GFX-NEXT: [[MUL1:%.*]] = fmul contract half [[X]], [[A_1]]
; GFX-NEXT: [[SUB]] = fsub contract half [[MUL1]], [[Y]]
; GFX-NEXT: br label %[[COMMON_RET]]
;
entry:
Expand Down Expand Up @@ -250,16 +252,17 @@ define bfloat @is_profitable_bfloat_ieee(ptr dereferenceable(8) %ptr_x, ptr dere
; GFX-NEXT: [[CMP:%.*]] = fcmp oeq bfloat [[Y]], 0xR0000
; GFX-NEXT: [[X:%.*]] = load bfloat, ptr [[PTR_X]], align 8
; GFX-NEXT: [[A_1:%.*]] = load bfloat, ptr [[PTR_A]], align 8
; GFX-NEXT: [[MUL:%.*]] = fmul contract bfloat [[X]], [[A_1]]
; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; GFX: [[COMMON_RET:.*]]:
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi bfloat [ [[MUL]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi bfloat [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: ret bfloat [[COMMON_RET_OP]]
; GFX: [[IF_THEN]]:
; GFX-NEXT: [[MUL]] = fmul contract bfloat [[X]], [[A_1]]
; GFX-NEXT: [[ADD:%.*]] = fadd contract bfloat 0xR3F80, [[MUL]]
; GFX-NEXT: br label %[[COMMON_RET]]
; GFX: [[IF_ELSE]]:
; GFX-NEXT: [[SUB]] = fsub contract bfloat [[MUL]], [[Y]]
; GFX-NEXT: [[MUL1:%.*]] = fmul contract bfloat [[X]], [[A_1]]
; GFX-NEXT: [[SUB]] = fsub contract bfloat [[MUL1]], [[Y]]
; GFX-NEXT: br label %[[COMMON_RET]]
;
entry:
Expand Down Expand Up @@ -330,16 +333,17 @@ define <8 x half> @is_profitable_vector(ptr dereferenceable(8) %ptr_x, ptr deref
; GFX-NEXT: [[V1:%.*]] = load <8 x half>, ptr addrspace(3) @v1_ptr, align 16
; GFX-NEXT: [[V2:%.*]] = load <8 x half>, ptr addrspace(3) @v2_ptr, align 16
; GFX-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00
; GFX-NEXT: [[MUL:%.*]] = fmul contract <8 x half> [[V1]], [[X]]
; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; GFX: [[COMMON_RET:.*]]:
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi <8 x half> [ [[ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: ret <8 x half> [[COMMON_RET_OP]]
; GFX: [[IF_THEN]]:
; GFX-NEXT: [[MUL:%.*]] = fmul contract <8 x half> [[V1]], [[X]]
; GFX-NEXT: [[ADD]] = fadd contract <8 x half> [[V2]], [[MUL]]
; GFX-NEXT: br label %[[COMMON_RET]]
; GFX: [[IF_ELSE]]:
; GFX-NEXT: [[SUB]] = fsub contract <8 x half> [[MUL]], [[V2]]
; GFX-NEXT: [[MUL1:%.*]] = fmul contract <8 x half> [[V1]], [[X]]
; GFX-NEXT: [[SUB]] = fsub contract <8 x half> [[MUL1]], [[V2]]
; GFX-NEXT: br label %[[COMMON_RET]]
;
entry:
Expand All @@ -362,23 +366,61 @@ if.else: ; preds = %entry
}

define double @is_profitable_f64_nocontract(ptr dereferenceable(8) %ptr_x, ptr dereferenceable(8) %ptr_y, ptr dereferenceable(8) %ptr_a) #0 {
; GFX-LABEL: define double @is_profitable_f64_nocontract(
; GFX-SAME: ptr dereferenceable(8) [[PTR_X:%.*]], ptr dereferenceable(8) [[PTR_Y:%.*]], ptr dereferenceable(8) [[PTR_A:%.*]]) #[[ATTR0]] {
; GFX-NEXT: [[Y:%.*]] = load double, ptr [[PTR_Y]], align 8
; GFX-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00
; GFX-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8
; GFX-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8
; GFX-NEXT: [[MUL:%.*]] = fmul double [[X]], [[A_1]]
; GFX-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; GFX: [[COMMON_RET:.*]]:
; GFX-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[PTR_ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; GFX-NEXT: ret double [[COMMON_RET_OP]]
; GFX: [[IF_THEN]]:
; GFX-NEXT: [[PTR_ADD]] = fadd double 1.000000e+00, [[MUL]]
; GFX-NEXT: br label %[[COMMON_RET]]
; GFX: [[IF_ELSE]]:
; GFX-NEXT: [[SUB]] = fsub double [[MUL]], [[Y]]
; GFX-NEXT: br label %[[COMMON_RET]]
; FP-CONTRACT-FAST-LABEL: define double @is_profitable_f64_nocontract(
; FP-CONTRACT-FAST-SAME: ptr dereferenceable(8) [[PTR_X:%.*]], ptr dereferenceable(8) [[PTR_Y:%.*]], ptr dereferenceable(8) [[PTR_A:%.*]]) #[[ATTR0]] {
; FP-CONTRACT-FAST-NEXT: [[Y:%.*]] = load double, ptr [[PTR_Y]], align 8
; FP-CONTRACT-FAST-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00
; FP-CONTRACT-FAST-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8
; FP-CONTRACT-FAST-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8
; FP-CONTRACT-FAST-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; FP-CONTRACT-FAST: [[COMMON_RET:.*]]:
; FP-CONTRACT-FAST-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[PTR_ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; FP-CONTRACT-FAST-NEXT: ret double [[COMMON_RET_OP]]
; FP-CONTRACT-FAST: [[IF_THEN]]:
; FP-CONTRACT-FAST-NEXT: [[MUL:%.*]] = fmul double [[X]], [[A_1]]
; FP-CONTRACT-FAST-NEXT: [[PTR_ADD]] = fadd double 1.000000e+00, [[MUL]]
; FP-CONTRACT-FAST-NEXT: br label %[[COMMON_RET]]
; FP-CONTRACT-FAST: [[IF_ELSE]]:
; FP-CONTRACT-FAST-NEXT: [[MUL1:%.*]] = fmul double [[X]], [[A_1]]
; FP-CONTRACT-FAST-NEXT: [[SUB]] = fsub double [[MUL1]], [[Y]]
; FP-CONTRACT-FAST-NEXT: br label %[[COMMON_RET]]
;
; UNSAFE-FP-MATH-LABEL: define double @is_profitable_f64_nocontract(
; UNSAFE-FP-MATH-SAME: ptr dereferenceable(8) [[PTR_X:%.*]], ptr dereferenceable(8) [[PTR_Y:%.*]], ptr dereferenceable(8) [[PTR_A:%.*]]) #[[ATTR0]] {
; UNSAFE-FP-MATH-NEXT: [[Y:%.*]] = load double, ptr [[PTR_Y]], align 8
; UNSAFE-FP-MATH-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00
; UNSAFE-FP-MATH-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8
; UNSAFE-FP-MATH-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8
; UNSAFE-FP-MATH-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; UNSAFE-FP-MATH: [[COMMON_RET:.*]]:
; UNSAFE-FP-MATH-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[PTR_ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; UNSAFE-FP-MATH-NEXT: ret double [[COMMON_RET_OP]]
; UNSAFE-FP-MATH: [[IF_THEN]]:
; UNSAFE-FP-MATH-NEXT: [[MUL:%.*]] = fmul double [[X]], [[A_1]]
; UNSAFE-FP-MATH-NEXT: [[PTR_ADD]] = fadd double 1.000000e+00, [[MUL]]
; UNSAFE-FP-MATH-NEXT: br label %[[COMMON_RET]]
; UNSAFE-FP-MATH: [[IF_ELSE]]:
; UNSAFE-FP-MATH-NEXT: [[MUL1:%.*]] = fmul double [[X]], [[A_1]]
; UNSAFE-FP-MATH-NEXT: [[SUB]] = fsub double [[MUL1]], [[Y]]
; UNSAFE-FP-MATH-NEXT: br label %[[COMMON_RET]]
;
; NO-UNSAFE-FP-MATH-LABEL: define double @is_profitable_f64_nocontract(
; NO-UNSAFE-FP-MATH-SAME: ptr dereferenceable(8) [[PTR_X:%.*]], ptr dereferenceable(8) [[PTR_Y:%.*]], ptr dereferenceable(8) [[PTR_A:%.*]]) #[[ATTR0]] {
; NO-UNSAFE-FP-MATH-NEXT: [[Y:%.*]] = load double, ptr [[PTR_Y]], align 8
; NO-UNSAFE-FP-MATH-NEXT: [[CMP:%.*]] = fcmp oeq double [[Y]], 0.000000e+00
; NO-UNSAFE-FP-MATH-NEXT: [[X:%.*]] = load double, ptr [[PTR_X]], align 8
; NO-UNSAFE-FP-MATH-NEXT: [[A_1:%.*]] = load double, ptr [[PTR_A]], align 8
; NO-UNSAFE-FP-MATH-NEXT: [[MUL:%.*]] = fmul double [[X]], [[A_1]]
; NO-UNSAFE-FP-MATH-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]]
; NO-UNSAFE-FP-MATH: [[COMMON_RET:.*]]:
; NO-UNSAFE-FP-MATH-NEXT: [[COMMON_RET_OP:%.*]] = phi double [ [[PTR_ADD:%.*]], %[[IF_THEN]] ], [ [[SUB:%.*]], %[[IF_ELSE]] ]
; NO-UNSAFE-FP-MATH-NEXT: ret double [[COMMON_RET_OP]]
; NO-UNSAFE-FP-MATH: [[IF_THEN]]:
; NO-UNSAFE-FP-MATH-NEXT: [[PTR_ADD]] = fadd double 1.000000e+00, [[MUL]]
; NO-UNSAFE-FP-MATH-NEXT: br label %[[COMMON_RET]]
; NO-UNSAFE-FP-MATH: [[IF_ELSE]]:
; NO-UNSAFE-FP-MATH-NEXT: [[SUB]] = fsub double [[MUL]], [[Y]]
; NO-UNSAFE-FP-MATH-NEXT: br label %[[COMMON_RET]]
;
%y = load double, ptr %ptr_y, align 8
%cmp = fcmp oeq double %y, 0.000000e+00
Expand All @@ -400,7 +442,3 @@ if.else: ; preds = %entry

attributes #0 = { nounwind "denormal-fp-math"="preserve-sign,preserve-sign" }
attributes #1 = { nounwind "denormal-fp-math"="ieee,ieee" }
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; FP-CONTRACT-FAST: {{.*}}
; NO-UNSAFE-FP-MATH: {{.*}}
; UNSAFE-FP-MATH: {{.*}}

0 comments on commit 211bcf6

Please sign in to comment.