Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Add assumption to preserve deref info after sinking. #120888

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2676,6 +2676,8 @@ class IRBuilderBase {
CallInst *CreateAlignmentAssumption(const DataLayout &DL, Value *PtrValue,
Value *Alignment,
Value *OffsetValue = nullptr);

CallInst *CreateDereferenceableAssumption(Value *PtrValue, unsigned Size);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attribute::getWithDereferenceableBytes uses uint64_t for the size argument.

};

/// This provides a uniform API for creating instructions and inserting
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/Transforms/InstCombine/InstCombine.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct InstCombineOptions {
// Verify that a fix point has been reached after MaxIterations.
bool VerifyFixpoint = false;
unsigned MaxIterations = InstCombineDefaultMaxIterations;
bool CleanupAssumptions = false;

InstCombineOptions() = default;

Expand All @@ -43,6 +44,11 @@ struct InstCombineOptions {
MaxIterations = Value;
return *this;
}

InstCombineOptions &setCleanupAssumptions(bool Value) {
CleanupAssumptions = Value;
return *this;
}
};

class InstCombinePass : public PassInfoMixin<InstCombinePass> {
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/IR/IRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,13 @@ CallInst *IRBuilderBase::CreateAlignmentAssumption(const DataLayout &DL,
return CreateAlignmentAssumptionHelper(DL, PtrValue, Alignment, OffsetValue);
}

CallInst *IRBuilderBase::CreateDereferenceableAssumption(Value *PtrValue,
unsigned Size) {
SmallVector<Value *, 4> Vals({PtrValue, getInt64(Size)});
OperandBundleDefT<Value *> AlignOpB("dereferenceable", Vals);
return CreateAssumption(ConstantInt::getTrue(getContext()), {AlignOpB});
}

IRBuilderDefaultInserter::~IRBuilderDefaultInserter() = default;
IRBuilderCallbackInserter::~IRBuilderCallbackInserter() = default;
IRBuilderFolder::~IRBuilderFolder() = default;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,8 @@ Expected<InstCombineOptions> parseInstCombineOptions(StringRef Params) {
ParamName).str(),
inconvertibleErrorCode());
Result.setMaxIterations((unsigned)MaxIterations.getZExtValue());
} else if (ParamName == "cleanup-assumptions") {
Result.setCleanupAssumptions(Enable);
} else {
return make_error<StringError>(
formatv("invalid InstCombine pass parameter '{0}' ", ParamName).str(),
Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/Passes/PassBuilderPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,8 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level,
FPM.addPass(LoopLoadEliminationPass());
}
// Cleanup after the loop optimization passes.
FPM.addPass(InstCombinePass());
FPM.addPass(
InstCombinePass(InstCombineOptions().setCleanupAssumptions(true)));

if (Level.getSpeedupLevel() > 1 && ExtraVectorizerPasses) {
ExtraFunctionPassManager<ShouldRunExtraVectorPasses> ExtraPasses;
Expand All @@ -1317,7 +1318,8 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level,
// dead (or speculatable) control flows or more combining opportunities.
ExtraPasses.addPass(EarlyCSEPass());
ExtraPasses.addPass(CorrelatedValuePropagationPass());
ExtraPasses.addPass(InstCombinePass());
ExtraPasses.addPass(
InstCombinePass(InstCombineOptions().setCleanupAssumptions(true)));
LoopPassManager LPM;
LPM.addPass(LICMPass(PTO.LicmMssaOptCap, PTO.LicmMssaNoAccForPromotionCap,
/*AllowSpeculation=*/true));
Expand All @@ -1328,7 +1330,8 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level,
/*UseBlockFrequencyInfo=*/true));
ExtraPasses.addPass(
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
ExtraPasses.addPass(InstCombinePass());
ExtraPasses.addPass(
InstCombinePass(InstCombineOptions().setCleanupAssumptions(true)));
FPM.addPass(std::move(ExtraPasses));
}

Expand All @@ -1351,7 +1354,8 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level,

if (IsFullLTO) {
FPM.addPass(SCCPPass());
FPM.addPass(InstCombinePass());
FPM.addPass(
InstCombinePass(InstCombineOptions().setCleanupAssumptions(true)));
FPM.addPass(BDCEPass());
}

Expand All @@ -1366,7 +1370,8 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level,
FPM.addPass(VectorCombinePass());

if (!IsFullLTO) {
FPM.addPass(InstCombinePass());
FPM.addPass(
InstCombinePass(InstCombineOptions().setCleanupAssumptions(true)));
// Unroll small loops to hide loop backedge latency and saturate any
// parallel execution resources of an out-of-order processor. We also then
// need to clean up redundancies and loop invariant code.
Expand All @@ -1392,7 +1397,8 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level,
}

FPM.addPass(InferAlignmentPass());
FPM.addPass(InstCombinePass());
FPM.addPass(
InstCombinePass(InstCombineOptions().setCleanupAssumptions(true)));

// This is needed for two reasons:
// 1. It works around problems that instcombine introduces, such as sinking
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ FUNCTION_PASS_WITH_PARAMS(
[](InstCombineOptions Opts) { return InstCombinePass(Opts); },
parseInstCombineOptions,
"no-use-loop-info;use-loop-info;no-verify-fixpoint;verify-fixpoint;"
"max-iterations=N")
"max-iterations=N;cleanup-assumptions")
FUNCTION_PASS_WITH_PARAMS(
"loop-unroll", "LoopUnrollPass",
[](LoopUnrollOptions Opts) { return LoopUnrollPass(Opts); },
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3172,6 +3172,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
MaybeSimplifyHint(OBU.Inputs[0]);
MaybeSimplifyHint(OBU.Inputs[1]);
}

// Try to clean up some assumption that are not very useful after this
// point.
if (CleanupAssumptions && OBU.getTagName() == "dereferenceable") {
auto *New = CallBase::removeOperandBundle(II, OBU.getTagID());
return New;
}
}

// Convert nonnull assume like:
Expand Down
8 changes: 6 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,20 @@ class User;
class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
: public InstCombiner,
public InstVisitor<InstCombinerImpl, Instruction *> {
bool CleanupAssumptions = false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This option should be taken into account by LastRunTrackingAnalysis.


public:
InstCombinerImpl(InstructionWorklist &Worklist, BuilderTy &Builder,
bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
DominatorTree &DT, OptimizationRemarkEmitter &ORE,
BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI,
ProfileSummaryInfo *PSI, const DataLayout &DL,
ReversePostOrderTraversal<BasicBlock *> &RPOT)
ReversePostOrderTraversal<BasicBlock *> &RPOT,
bool CleanupAssumptions)
: InstCombiner(Worklist, Builder, MinimizeSize, AA, AC, TLI, TTI, DT, ORE,
BFI, BPI, PSI, DL, RPOT) {}
BFI, BPI, PSI, DL, RPOT),
CleanupAssumptions(CleanupAssumptions) {}

virtual ~InstCombinerImpl() = default;

Expand Down
19 changes: 16 additions & 3 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
Expand Down Expand Up @@ -4872,6 +4873,16 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I,
/// the new position.

BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt();

if (!CleanupAssumptions && isa<LoadInst>(I)) {
// Preserve dereferenceable at original position.
// TODO: Only need to add this extra information if I doesn't always execute
// in the new position.
Builder.SetInsertPoint(I);
Value *Ptr = I->getOperand(0);
Builder.CreateDereferenceableAssumption(
Ptr, I->getType()->getScalarSizeInBits());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loads of vectors and pointers are not handled here.

}
I->moveBefore(*DestBlock, InsertPos);
++NumSunkInst;

Expand Down Expand Up @@ -5133,7 +5144,9 @@ bool InstCombinerImpl::run() {

for (Use &U : I->uses()) {
User *User = U.getUser();
if (User->isDroppable())
if (User->isDroppable() &&
(!I->getType()->isPointerTy() ||
!getKnowledgeForValue(I, Attribute::Dereferenceable, &AC)))
continue;
if (NumUsers > MaxSinkNumUsers)
return std::nullopt;
Expand Down Expand Up @@ -5524,7 +5537,7 @@ static bool combineInstructionsOverFunction(
<< F.getName() << "\n");

InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT,
ORE, BFI, BPI, PSI, DL, RPOT);
ORE, BFI, BPI, PSI, DL, RPOT, Opts.CleanupAssumptions);
IC.MaxArraySizeForCombine = MaxArraySize;
bool MadeChangeInThisIteration = IC.prepareWorklist(F);
MadeChangeInThisIteration |= IC.run();
Expand Down Expand Up @@ -5573,7 +5586,7 @@ PreservedAnalyses InstCombinePass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &LRT = AM.getResult<LastRunTrackingAnalysis>(F);
// No changes since last InstCombine pass, exit early.
if (LRT.shouldSkip(&ID))
if (LRT.shouldSkip(&ID) && !Options.CleanupAssumptions)
return PreservedAnalyses::all();

auto &AC = AM.getResult<AssumptionAnalysis>(F);
Expand Down
15 changes: 10 additions & 5 deletions llvm/test/Transforms/InstCombine/assume.ll
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,12 @@ define i1 @nonnull3(ptr %a, i1 %control) {
;
; BUNDLES-LABEL: @nonnull3(
; BUNDLES-NEXT: entry:
; BUNDLES-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[A:%.*]], i64 0) ]
; BUNDLES-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]]
; BUNDLES: taken:
; BUNDLES-NEXT: ret i1 false
; BUNDLES: not_taken:
; BUNDLES-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8
; BUNDLES-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A]], align 8
; BUNDLES-NEXT: [[RVAL_2:%.*]] = icmp sgt ptr [[LOAD]], null
; BUNDLES-NEXT: ret i1 [[RVAL_2]]
;
Expand Down Expand Up @@ -454,11 +455,12 @@ define i1 @nonnull3A(ptr %a, i1 %control) {
;
; BUNDLES-LABEL: @nonnull3A(
; BUNDLES-NEXT: entry:
; BUNDLES-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[A:%.*]], i64 0) ]
; BUNDLES-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]]
; BUNDLES: taken:
; BUNDLES-NEXT: ret i1 true
; BUNDLES: not_taken:
; BUNDLES-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8
; BUNDLES-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A]], align 8
; BUNDLES-NEXT: [[RVAL_2:%.*]] = icmp sgt ptr [[LOAD]], null
; BUNDLES-NEXT: ret i1 [[RVAL_2]]
;
Expand All @@ -478,9 +480,10 @@ not_taken:
define i1 @nonnull3B(ptr %a, i1 %control) {
; CHECK-LABEL: @nonnull3B(
; CHECK-NEXT: entry:
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[A:%.*]], i64 0) ]
; CHECK-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]]
; CHECK: taken:
; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8
; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A]], align 8
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr [[LOAD]], null
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) [ "nonnull"(ptr [[LOAD]]) ]
; CHECK-NEXT: ret i1 [[CMP]]
Expand All @@ -504,9 +507,10 @@ declare i1 @tmp1(i1)
define i1 @nonnull3C(ptr %a, i1 %control) {
; CHECK-LABEL: @nonnull3C(
; CHECK-NEXT: entry:
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[A:%.*]], i64 0) ]
; CHECK-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]]
; CHECK: taken:
; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8
; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A]], align 8
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr [[LOAD]], null
; CHECK-NEXT: [[CMP2:%.*]] = call i1 @tmp1(i1 [[CMP]])
; CHECK-NEXT: br label [[EXIT:%.*]]
Expand Down Expand Up @@ -534,9 +538,10 @@ not_taken:
define i1 @nonnull3D(ptr %a, i1 %control) {
; CHECK-LABEL: @nonnull3D(
; CHECK-NEXT: entry:
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[A:%.*]], i64 0) ]
; CHECK-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]]
; CHECK: taken:
; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8
; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A]], align 8
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr [[LOAD]], null
; CHECK-NEXT: [[CMP2:%.*]] = call i1 @tmp1(i1 [[CMP]])
; CHECK-NEXT: br label [[EXIT:%.*]]
Expand Down
12 changes: 8 additions & 4 deletions llvm/test/Transforms/InstCombine/select-cmp-br.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ define void @test1(ptr %arg) {
; CHECK-NEXT: [[M:%.*]] = load ptr, ptr [[ARG:%.*]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[ARG]], i64 16
; CHECK-NEXT: [[N:%.*]] = load ptr, ptr [[TMP1]], align 8
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[M]], i64 72
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[TMP2]], i64 0) ]
; CHECK-NEXT: [[TMP5_NOT:%.*]] = icmp eq ptr [[M]], [[N]]
; CHECK-NEXT: br i1 [[TMP5_NOT]], label [[BB8:%.*]], label [[BB10:%.*]]
; CHECK: bb:
Expand All @@ -22,7 +24,6 @@ define void @test1(ptr %arg) {
; CHECK-NEXT: tail call void @bar(ptr nonnull [[ARG]])
; CHECK-NEXT: br label [[BB:%.*]]
; CHECK: bb10:
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[M]], i64 72
; CHECK-NEXT: [[TMP4:%.*]] = load ptr, ptr [[TMP2]], align 8
; CHECK-NEXT: [[TMP11:%.*]] = tail call i64 [[TMP4]](ptr nonnull [[ARG]])
; CHECK-NEXT: br label [[BB]]
Expand Down Expand Up @@ -56,6 +57,8 @@ define void @test2(ptr %arg) {
; CHECK-NEXT: [[M:%.*]] = load ptr, ptr [[ARG:%.*]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[ARG]], i64 16
; CHECK-NEXT: [[N:%.*]] = load ptr, ptr [[TMP1]], align 8
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[M]], i64 72
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[TMP2]], i64 0) ]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq ptr [[M]], [[N]]
; CHECK-NEXT: br i1 [[TMP5]], label [[BB10:%.*]], label [[BB8:%.*]]
; CHECK: bb:
Expand All @@ -64,7 +67,6 @@ define void @test2(ptr %arg) {
; CHECK-NEXT: tail call void @bar(ptr nonnull [[ARG]])
; CHECK-NEXT: br label [[BB:%.*]]
; CHECK: bb10:
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[M]], i64 72
; CHECK-NEXT: [[TMP4:%.*]] = load ptr, ptr [[TMP2]], align 8
; CHECK-NEXT: [[TMP11:%.*]] = tail call i64 [[TMP4]](ptr nonnull [[ARG]])
; CHECK-NEXT: br label [[BB]]
Expand Down Expand Up @@ -98,6 +100,8 @@ define void @test3(ptr %arg) {
; CHECK-NEXT: [[M:%.*]] = load ptr, ptr [[ARG:%.*]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[ARG]], i64 16
; CHECK-NEXT: [[N:%.*]] = load ptr, ptr [[TMP1]], align 8
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[M]], i64 72
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[TMP2]], i64 0) ]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq ptr [[M]], [[N]]
; CHECK-NEXT: br i1 [[TMP5]], label [[BB8:%.*]], label [[BB10:%.*]]
; CHECK: bb:
Expand All @@ -106,7 +110,6 @@ define void @test3(ptr %arg) {
; CHECK-NEXT: tail call void @bar(ptr nonnull [[ARG]])
; CHECK-NEXT: br label [[BB:%.*]]
; CHECK: bb10:
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[M]], i64 72
; CHECK-NEXT: [[TMP4:%.*]] = load ptr, ptr [[TMP2]], align 8
; CHECK-NEXT: [[TMP11:%.*]] = tail call i64 [[TMP4]](ptr nonnull [[ARG]])
; CHECK-NEXT: br label [[BB]]
Expand Down Expand Up @@ -140,6 +143,8 @@ define void @test4(ptr %arg) {
; CHECK-NEXT: [[M:%.*]] = load ptr, ptr [[ARG:%.*]], align 8
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[ARG]], i64 16
; CHECK-NEXT: [[N:%.*]] = load ptr, ptr [[TMP1]], align 8
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[M]], i64 72
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "dereferenceable"(ptr [[TMP2]], i64 0) ]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq ptr [[M]], [[N]]
; CHECK-NEXT: br i1 [[TMP5]], label [[BB10:%.*]], label [[BB8:%.*]]
; CHECK: bb:
Expand All @@ -148,7 +153,6 @@ define void @test4(ptr %arg) {
; CHECK-NEXT: tail call void @bar(ptr nonnull [[ARG]])
; CHECK-NEXT: br label [[BB:%.*]]
; CHECK: bb10:
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[M]], i64 72
; CHECK-NEXT: [[TMP4:%.*]] = load ptr, ptr [[TMP2]], align 8
; CHECK-NEXT: [[TMP11:%.*]] = tail call i64 [[TMP4]](ptr nonnull [[ARG]])
; CHECK-NEXT: br label [[BB]]
Expand Down
Loading
Loading