Skip to content

Commit

Permalink
[SCEV] Iteratively compute ranges for deeply nested expressions.
Browse files Browse the repository at this point in the history
At the moment, getRangeRef may overflow the stack for very deeply nested
expressions.

This patch introduces a new getRangeRefIter function, which first builds
a worklist of N-ary expressions and phi nodes, followed by their
operands iteratively.

getRangeRef has been extended to also take a Depth argument and it
switches to use getRangeRefIter once the depth reaches a certain
threshold.

This ensures compile-time is not impacted in general. Note that
the iterative algorithm may lead to a slightly different evaluation
order, which could result in slightly worse ranges for cyclic phis.

https://llvm-compile-time-tracker.com/compare.php?from=23c3eb7cdf3478c9db86f6cb5115821a8f0f5f40&to=e0e09fa338e77e53242bfc846e1484350ad79773&stat=instructions

Fixes llvm#49579.

Reviewed By: mkazantsev

Differential Revision: https://reviews.llvm.org/D130728
  • Loading branch information
fhahn committed Nov 28, 2022
1 parent fcaaa82 commit c38d16c
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 20 deletions.
10 changes: 9 additions & 1 deletion llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,9 @@ class ScalarEvolution {
/// Mark SCEVUnknown Phis currently being processed by getRangeRef.
SmallPtrSet<const PHINode *, 6> PendingPhiRanges;

/// Mark SCEVUnknown Phis currently being processed by getRangeRefIter.
SmallPtrSet<const PHINode *, 6> PendingPhiRangesIter;

// Mark SCEVUnknown Phis currently being processed by isImpliedViaMerge.
SmallPtrSet<const PHINode *, 6> PendingMerges;

Expand Down Expand Up @@ -1560,7 +1563,12 @@ class ScalarEvolution {
/// Determine the range for a particular SCEV.
/// NOTE: This returns a reference to an entry in a cache. It must be
/// copied if its needed for longer.
const ConstantRange &getRangeRef(const SCEV *S, RangeSignHint Hint);
const ConstantRange &getRangeRef(const SCEV *S, RangeSignHint Hint,
unsigned Depth = 0);

/// Determine the range for a particular SCEV, but evaluates ranges for
/// operands iteratively first.
const ConstantRange &getRangeRefIter(const SCEV *S, RangeSignHint Hint);

/// Determines the range for the affine SCEVAddRecExpr {\p Start,+,\p Step}.
/// Helper for \c getRange.
Expand Down
110 changes: 91 additions & 19 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ static cl::opt<unsigned>
cl::desc("Size of the expression which is considered huge"),
cl::init(4096));

static cl::opt<unsigned> RangeIterThreshold(
"scev-range-iter-threshold", cl::Hidden,
cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
cl::init(32));

static cl::opt<bool>
ClassifyExpressions("scalar-evolution-classify-expressions",
cl::Hidden, cl::init(true),
Expand Down Expand Up @@ -6425,18 +6430,78 @@ getRangeForUnknownRecurrence(const SCEVUnknown *U) {
return FullSet;
}

const ConstantRange &
ScalarEvolution::getRangeRefIter(const SCEV *S,
ScalarEvolution::RangeSignHint SignHint) {
DenseMap<const SCEV *, ConstantRange> &Cache =
SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
: SignedRanges;
SmallVector<const SCEV *> WorkList;
SmallPtrSet<const SCEV *, 8> Seen;

// Add Expr to the worklist, if Expr is either an N-ary expression or a
// SCEVUnknown PHI node.
auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
if (!Seen.insert(Expr).second)
return;
if (Cache.find(Expr) != Cache.end())
return;
if (isa<SCEVNAryExpr>(Expr) || isa<SCEVUDivExpr>(Expr))
WorkList.push_back(Expr);
else if (auto *UnknownS = dyn_cast<SCEVUnknown>(Expr))
if (isa<PHINode>(UnknownS->getValue()))
WorkList.push_back(Expr);
};
AddToWorklist(S);

// Build worklist by queuing operands of N-ary expressions and phi nodes.
for (unsigned I = 0; I != WorkList.size(); ++I) {
const SCEV *P = WorkList[I];
if (auto *NaryS = dyn_cast<SCEVNAryExpr>(P)) {
for (const SCEV *Op : NaryS->operands())
AddToWorklist(Op);
} else if (auto *UDiv = dyn_cast<SCEVUDivExpr>(P)) {
AddToWorklist(UDiv->getLHS());
AddToWorklist(UDiv->getRHS());
} else {
auto *UnknownS = cast<SCEVUnknown>(P);
if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
if (!PendingPhiRangesIter.insert(P).second)
continue;
for (auto &Op : reverse(P->operands()))
AddToWorklist(getSCEV(Op));
}
}
}

if (!WorkList.empty()) {
// Use getRangeRef to compute ranges for items in the worklist in reverse
// order. This will force ranges for earlier operands to be computed before
// their users in most cases.
for (const SCEV *P :
reverse(make_range(WorkList.begin() + 1, WorkList.end()))) {
getRangeRef(P, SignHint);

if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
PendingPhiRangesIter.erase(P);
}
}

return getRangeRef(S, SignHint, 0);
}

/// Determine the range for a particular SCEV. If SignHint is
/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
/// with a "cleaner" unsigned (resp. signed) representation.
const ConstantRange &
ScalarEvolution::getRangeRef(const SCEV *S,
ScalarEvolution::RangeSignHint SignHint) {
const ConstantRange &ScalarEvolution::getRangeRef(
const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
DenseMap<const SCEV *, ConstantRange> &Cache =
SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
: SignedRanges;
ConstantRange::PreferredRangeType RangeType =
SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED
? ConstantRange::Unsigned : ConstantRange::Signed;
SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
: ConstantRange::Signed;

// See if we've computed this range already.
DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
Expand All @@ -6446,6 +6511,11 @@ ScalarEvolution::getRangeRef(const SCEV *S,
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
return setRange(C, SignHint, ConstantRange(C->getAPInt()));

// Switch to iteratively computing the range for S, if it is part of a deeply
// nested expression.
if (Depth > RangeIterThreshold)
return getRangeRefIter(S, SignHint);

unsigned BitWidth = getTypeSizeInBits(S->getType());
ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
using OBO = OverflowingBinaryOperator;
Expand All @@ -6465,23 +6535,23 @@ ScalarEvolution::getRangeRef(const SCEV *S,
}

if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
unsigned WrapType = OBO::AnyWrap;
if (Add->hasNoSignedWrap())
WrapType |= OBO::NoSignedWrap;
if (Add->hasNoUnsignedWrap())
WrapType |= OBO::NoUnsignedWrap;
for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint),
X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1),
WrapType, RangeType);
return setRange(Add, SignHint,
ConservativeResult.intersectWith(X, RangeType));
}

if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
return setRange(Mul, SignHint,
ConservativeResult.intersectWith(X, RangeType));
}
Expand All @@ -6507,41 +6577,42 @@ ScalarEvolution::getRangeRef(const SCEV *S,
}

const auto *NAry = cast<SCEVNAryExpr>(S);
ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint);
ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
X = X.intrinsic(ID, {X, getRangeRef(NAry->getOperand(i), SignHint)});
X = X.intrinsic(
ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
return setRange(S, SignHint,
ConservativeResult.intersectWith(X, RangeType));
}

if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
return setRange(UDiv, SignHint,
ConservativeResult.intersectWith(X.udiv(Y), RangeType));
}

if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
return setRange(ZExt, SignHint,
ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
RangeType));
}

if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
return setRange(SExt, SignHint,
ConservativeResult.intersectWith(X.signExtend(BitWidth),
RangeType));
}

if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint);
ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
return setRange(PtrToInt, SignHint, X);
}

if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
return setRange(Trunc, SignHint,
ConservativeResult.intersectWith(X.truncate(BitWidth),
RangeType));
Expand Down Expand Up @@ -6671,12 +6742,13 @@ ScalarEvolution::getRangeRef(const SCEV *S,
RangeType);

// A range of Phi is a subset of union of all ranges of its input.
if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
if (PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
// Make sure that we do not run over cycled Phis.
if (PendingPhiRanges.insert(Phi).second) {
ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);

for (const auto &Op : Phi->operands()) {
auto OpRange = getRangeRef(getSCEV(Op), SignHint);
auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
RangeFromOps = RangeFromOps.unionWith(OpRange);
// No point to continue if we already have a full set.
if (RangeFromOps.isFullSet())
Expand Down
1 change: 1 addition & 0 deletions llvm/test/Analysis/ScalarEvolution/ranges.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" -scev-range-iter-threshold=1 2>&1 | FileCheck %s

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64"

Expand Down
73 changes: 73 additions & 0 deletions llvm/test/Transforms/IndVarSimplify/range-iter-threshold.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
; RUN: opt -passes=indvars -S %s | FileCheck --check-prefix=COMMON --check-prefix=DEFAULT %s
; RUN: opt -passes=indvars -scev-range-iter-threshold=1 -S %s | FileCheck --check-prefix=COMMON --check-prefix=LIMIT %s

target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"

define i32 @test(i1 %c.0, i32 %m) {
; COMMON-LABEL: @test(
; COMMON-NEXT: entry:
; COMMON-NEXT: br label [[OUTER_HEADER:%.*]]
; COMMON: outer.header:
; DEFAULT-NEXT: [[INDVARS_IV:%.*]] = phi i32 [ [[INDVARS_IV_NEXT:%.*]], [[OUTER_LATCH:%.*]] ], [ 2, [[ENTRY:%.*]] ]
; COMMON-NEXT: [[IV_1:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[IV_1_NEXT:%.*]], [[OUTER_LATCH:%.*]] ]
; COMMON-NEXT: [[MAX_0:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[MAX_1:%.*]], [[OUTER_LATCH]] ]
; COMMON-NEXT: [[TMP0:%.*]] = sext i32 [[IV_1]] to i64
; COMMON-NEXT: br label [[INNER_1:%.*]]
; COMMON: inner.1:
; COMMON-NEXT: [[C_1:%.*]] = icmp slt i64 0, [[TMP0]]
; COMMON-NEXT: br i1 [[C_1]], label [[INNER_1]], label [[INNER_2_HEADER_PREHEADER:%.*]]
; COMMON: inner.2.header.preheader:
; COMMON-NEXT: br label [[INNER_2_HEADER:%.*]]
; COMMON: inner.2.header:
; COMMON-NEXT: [[IV_3:%.*]] = phi i32 [ [[IV_3_NEXT:%.*]], [[INNER_2_LATCH:%.*]] ], [ 0, [[INNER_2_HEADER_PREHEADER]] ]
; COMMON-NEXT: br i1 [[C_0:%.*]], label [[OUTER_LATCH]], label [[INNER_2_LATCH]]
; COMMON: inner.2.latch:
; COMMON-NEXT: [[IV_3_NEXT]] = add i32 [[IV_3]], 1
; DEFAULT-NEXT: [[EXITCOND:%.*]] = icmp eq i32 [[IV_3_NEXT]], [[INDVARS_IV]]
; LIMIT-NEXT: [[EXITCOND:%.*]] = icmp ugt i32 [[IV_3]], [[IV_1]]
; COMMON-NEXT: br i1 [[EXITCOND]], label [[OUTER_LATCH]], label [[INNER_2_HEADER]]
; COMMON: outer.latch:
; COMMON-NEXT: [[MAX_1]] = phi i32 [ [[M:%.*]], [[INNER_2_LATCH]] ], [ 0, [[INNER_2_HEADER]] ]
; COMMON-NEXT: [[IV_1_NEXT]] = add nuw i32 [[IV_1]], 1
; COMMON-NEXT: [[C_3:%.*]] = icmp ugt i32 [[IV_1]], [[MAX_0]]
; DEFAULT-NEXT: [[INDVARS_IV_NEXT]] = add i32 [[INDVARS_IV]], 1
; COMMON-NEXT: br i1 [[C_3]], label [[EXIT:%.*]], label [[OUTER_HEADER]], !llvm.loop [[LOOP0:![0-9]+]]
; COMMON: exit:
; COMMON-NEXT: ret i32 0
;
entry:
br label %outer.header

outer.header:
%iv.1 = phi i32 [ 0, %entry ], [ %iv.1.next, %outer.latch ]
%iv.2 = phi i32 [ 0, %entry ], [ %iv.2.next , %outer.latch ]
%max.0 = phi i32 [ 0, %entry ], [ %max.1, %outer.latch ]
%0 = sext i32 %iv.1 to i64
br label %inner.1

inner.1:
%c.1 = icmp slt i64 0, %0
br i1 %c.1, label %inner.1, label %inner.2.header

inner.2.header:
%iv.3 = phi i32 [ 0, %inner.1 ], [ %iv.3.next, %inner.2.latch ]
br i1 %c.0, label %outer.latch, label %inner.2.latch

inner.2.latch:
%iv.3.next = add i32 %iv.3, 1
%c.2 = icmp ugt i32 %iv.3, %iv.2
br i1 %c.2, label %outer.latch, label %inner.2.header

outer.latch:
%max.1 = phi i32 [ %m, %inner.2.latch ], [ %iv.3, %inner.2.header ]
%iv.1.next = add i32 %iv.1, 1
%iv.2.next = add i32 %iv.2, 1
%c.3 = icmp ugt i32 %iv.2, %max.0
br i1 %c.3, label %exit, label %outer.header, !llvm.loop !0

exit:
ret i32 0
}

!0 = distinct !{!0, !1}
!1 = !{!"llvm.loop.mustprogress"}

0 comments on commit c38d16c

Please sign in to comment.