Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use BlockFrequency type in more places (NFC) #68266

Merged
merged 2 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions llvm/include/llvm/Analysis/BlockFrequencyInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,19 @@ class BlockFrequencyInfo {
/// Returns the estimated profile count of \p Freq.
/// This uses the frequency \p Freq and multiplies it by
/// the enclosing function's count (if available) and returns the value.
std::optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
std::optional<uint64_t> getProfileCountFromFreq(BlockFrequency Freq) const;

/// Returns true if \p BB is an irreducible loop header
/// block. Otherwise false.
bool isIrrLoopHeader(const BasicBlock *BB);

// Set the frequency of the given basic block.
void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
void setBlockFreq(const BasicBlock *BB, BlockFrequency Freq);

/// Set the frequency of \p ReferenceBB to \p Freq and scale the frequencies
/// of the blocks in \p BlocksToScale such that their frequencies relative
/// to \p ReferenceBB remain unchanged.
void setBlockFreqAndScale(const BasicBlock *ReferenceBB, uint64_t Freq,
void setBlockFreqAndScale(const BasicBlock *ReferenceBB, BlockFrequency Freq,
SmallPtrSetImpl<BasicBlock *> &BlocksToScale);

/// calculate - compute block frequency info for the given function.
Expand All @@ -94,13 +94,13 @@ class BlockFrequencyInfo {

// Print the block frequency Freq to OS using the current functions entry
// frequency to convert freq into a relative decimal form.
raw_ostream &printBlockFreq(raw_ostream &OS, const BlockFrequency Freq) const;
raw_ostream &printBlockFreq(raw_ostream &OS, BlockFrequency Freq) const;

// Convenience method that attempts to look up the frequency associated with
// BB and print it to OS.
raw_ostream &printBlockFreq(raw_ostream &OS, const BasicBlock *BB) const;

uint64_t getEntryFreq() const;
BlockFrequency getEntryFreq() const;
void releaseMemory();
void print(raw_ostream &OS) const;

Expand Down
20 changes: 10 additions & 10 deletions llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -527,19 +527,18 @@ class BlockFrequencyInfoImplBase {
getBlockProfileCount(const Function &F, const BlockNode &Node,
bool AllowSynthetic = false) const;
std::optional<uint64_t>
getProfileCountFromFreq(const Function &F, uint64_t Freq,
getProfileCountFromFreq(const Function &F, BlockFrequency Freq,
bool AllowSynthetic = false) const;
bool isIrrLoopHeader(const BlockNode &Node);

void setBlockFreq(const BlockNode &Node, uint64_t Freq);
void setBlockFreq(const BlockNode &Node, BlockFrequency Freq);

raw_ostream &printBlockFreq(raw_ostream &OS, const BlockNode &Node) const;
raw_ostream &printBlockFreq(raw_ostream &OS,
const BlockFrequency &Freq) const;
raw_ostream &printBlockFreq(raw_ostream &OS, BlockFrequency Freq) const;

uint64_t getEntryFreq() const {
BlockFrequency getEntryFreq() const {
assert(!Freqs.empty());
return Freqs[0].Integer;
return BlockFrequency(Freqs[0].Integer);
}
};

Expand Down Expand Up @@ -1029,7 +1028,7 @@ template <class BT> class BlockFrequencyInfoImpl : BlockFrequencyInfoImplBase {
}

std::optional<uint64_t>
getProfileCountFromFreq(const Function &F, uint64_t Freq,
getProfileCountFromFreq(const Function &F, BlockFrequency Freq,
bool AllowSynthetic = false) const {
return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq,
AllowSynthetic);
Expand All @@ -1039,7 +1038,7 @@ template <class BT> class BlockFrequencyInfoImpl : BlockFrequencyInfoImplBase {
return BlockFrequencyInfoImplBase::isIrrLoopHeader(getNode(BB));
}

void setBlockFreq(const BlockT *BB, uint64_t Freq);
void setBlockFreq(const BlockT *BB, BlockFrequency Freq);

void forgetBlock(const BlockT *BB) {
// We don't erase corresponding items from `Freqs`, `RPOT` and other to
Expand Down Expand Up @@ -1145,12 +1144,13 @@ void BlockFrequencyInfoImpl<BT>::calculate(const FunctionT &F,
// blocks and unknown blocks.
for (const BlockT &BB : F)
if (!Nodes.count(&BB))
setBlockFreq(&BB, 0);
setBlockFreq(&BB, BlockFrequency());
}
}

template <class BT>
void BlockFrequencyInfoImpl<BT>::setBlockFreq(const BlockT *BB, uint64_t Freq) {
void BlockFrequencyInfoImpl<BT>::setBlockFreq(const BlockT *BB,
BlockFrequency Freq) {
if (Nodes.count(BB))
BlockFrequencyInfoImplBase::setBlockFreq(getNode(BB), Freq);
else {
Expand Down
4 changes: 2 additions & 2 deletions llvm/include/llvm/Analysis/ProfileSummaryInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class ProfileSummaryInfo {

template <typename BFIT>
bool isColdBlock(BlockFrequency BlockFreq, const BFIT *BFI) const {
auto Count = BFI->getProfileCountFromFreq(BlockFreq.getFrequency());
auto Count = BFI->getProfileCountFromFreq(BlockFreq);
return Count && isColdCount(*Count);
}

Expand Down Expand Up @@ -315,7 +315,7 @@ class ProfileSummaryInfo {
bool isHotOrColdBlockNthPercentile(int PercentileCutoff,
BlockFrequency BlockFreq,
BFIT *BFI) const {
auto Count = BFI->getProfileCountFromFreq(BlockFreq.getFrequency());
auto Count = BFI->getProfileCountFromFreq(BlockFreq);
if (isHot)
return Count && isHotCountNthPercentile(PercentileCutoff, *Count);
else
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/CodeGen/GlobalISel/RegBankSelect.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ class RegBankSelect : public MachineFunctionPass {
public:
/// Create a MappingCost assuming that most of the instructions
/// will occur in a basic block with \p LocalFreq frequency.
MappingCost(const BlockFrequency &LocalFreq);
MappingCost(BlockFrequency LocalFreq);

/// Add \p Cost to the local cost.
/// \return true if this cost is saturated, false otherwise.
Expand Down
5 changes: 2 additions & 3 deletions llvm/include/llvm/CodeGen/MBFIWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ class MBFIWrapper {

raw_ostream &printBlockFreq(raw_ostream &OS,
const MachineBasicBlock *MBB) const;
raw_ostream &printBlockFreq(raw_ostream &OS,
const BlockFrequency Freq) const;
raw_ostream &printBlockFreq(raw_ostream &OS, BlockFrequency Freq) const;
void view(const Twine &Name, bool isSimple = true);
uint64_t getEntryFreq() const;
BlockFrequency getEntryFreq() const;
const MachineBlockFrequencyInfo &getMBFI() { return MBFI; }

private:
Expand Down
9 changes: 5 additions & 4 deletions llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,15 @@ class MachineBlockFrequencyInfo : public MachineFunctionPass {
/// Compute the frequency of the block, relative to the entry block.
/// This API assumes getEntryFreq() is non-zero.
double getBlockFreqRelativeToEntryBlock(const MachineBasicBlock *MBB) const {
assert(getEntryFreq() != 0 && "getEntryFreq() should not return 0 here!");
assert(getEntryFreq() != BlockFrequency(0) &&
"getEntryFreq() should not return 0 here!");
return static_cast<double>(getBlockFreq(MBB).getFrequency()) /
static_cast<double>(getEntryFreq());
static_cast<double>(getEntryFreq().getFrequency());
}

std::optional<uint64_t>
getBlockProfileCount(const MachineBasicBlock *MBB) const;
std::optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
std::optional<uint64_t> getProfileCountFromFreq(BlockFrequency Freq) const;

bool isIrrLoopHeader(const MachineBasicBlock *MBB) const;

Expand Down Expand Up @@ -101,7 +102,7 @@ class MachineBlockFrequencyInfo : public MachineFunctionPass {

/// Divide a block's BlockFrequency::getFrequency() value by this value to
/// obtain the entry block - relative frequency of said block.
uint64_t getEntryFreq() const;
BlockFrequency getEntryFreq() const;
};

} // end namespace llvm
Expand Down
9 changes: 7 additions & 2 deletions llvm/include/llvm/Support/BlockFrequency.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ class BlockFrequency {
uint64_t Frequency;

public:
BlockFrequency(uint64_t Freq = 0) : Frequency(Freq) { }
BlockFrequency() : Frequency(0) {}
MatzeB marked this conversation as resolved.
Show resolved Hide resolved
explicit BlockFrequency(uint64_t Freq) : Frequency(Freq) {}

/// Returns the maximum possible frequency, the saturation value.
static uint64_t getMaxFrequency() { return UINT64_MAX; }
static BlockFrequency max() { return BlockFrequency(UINT64_MAX); }

/// Returns the frequency as a fixpoint number scaled by the entry
/// frequency.
Expand Down Expand Up @@ -112,6 +113,10 @@ class BlockFrequency {
bool operator==(BlockFrequency RHS) const {
return Frequency == RHS.Frequency;
}

bool operator!=(BlockFrequency RHS) const {
return Frequency != RHS.Frequency;
}
};

} // namespace llvm
Expand Down
3 changes: 2 additions & 1 deletion llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ template <class Edge, class BBInfo> class CFGMST {
LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");

BasicBlock *Entry = &(F.getEntryBlock());
uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
uint64_t EntryWeight =
(BFI != nullptr ? BFI->getEntryFreq().getFrequency() : 2);
// If we want to instrument the entry count, lower the weight to 0.
if (InstrumentFuncEntry)
EntryWeight = 0;
Expand Down
21 changes: 11 additions & 10 deletions llvm/lib/Analysis/BlockFrequencyInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ void BlockFrequencyInfo::calculate(const Function &F,
}

BlockFrequency BlockFrequencyInfo::getBlockFreq(const BasicBlock *BB) const {
return BFI ? BFI->getBlockFreq(BB) : 0;
return BFI ? BFI->getBlockFreq(BB) : BlockFrequency(0);
}

std::optional<uint64_t>
Expand All @@ -214,7 +214,7 @@ BlockFrequencyInfo::getBlockProfileCount(const BasicBlock *BB,
}

std::optional<uint64_t>
BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
BlockFrequencyInfo::getProfileCountFromFreq(BlockFrequency Freq) const {
if (!BFI)
return std::nullopt;
return BFI->getProfileCountFromFreq(*getFunction(), Freq);
Expand All @@ -225,17 +225,18 @@ bool BlockFrequencyInfo::isIrrLoopHeader(const BasicBlock *BB) {
return BFI->isIrrLoopHeader(BB);
}

void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB,
BlockFrequency Freq) {
assert(BFI && "Expected analysis to be available");
BFI->setBlockFreq(BB, Freq);
}

void BlockFrequencyInfo::setBlockFreqAndScale(
const BasicBlock *ReferenceBB, uint64_t Freq,
const BasicBlock *ReferenceBB, BlockFrequency Freq,
SmallPtrSetImpl<BasicBlock *> &BlocksToScale) {
assert(BFI && "Expected analysis to be available");
// Use 128 bits APInt to avoid overflow.
APInt NewFreq(128, Freq);
APInt NewFreq(128, Freq.getFrequency());
APInt OldFreq(128, BFI->getBlockFreq(ReferenceBB).getFrequency());
APInt BBFreq(128, 0);
for (auto *BB : BlocksToScale) {
Expand All @@ -247,7 +248,7 @@ void BlockFrequencyInfo::setBlockFreqAndScale(
// a hot spot, one of the options proposed in
// https://reviews.llvm.org/D28535#650071 could be used to avoid this.
BBFreq = BBFreq.udiv(OldFreq);
BFI->setBlockFreq(BB, BBFreq.getLimitedValue());
BFI->setBlockFreq(BB, BlockFrequency(BBFreq.getLimitedValue()));
}
BFI->setBlockFreq(ReferenceBB, Freq);
}
Expand All @@ -266,8 +267,8 @@ const BranchProbabilityInfo *BlockFrequencyInfo::getBPI() const {
return BFI ? &BFI->getBPI() : nullptr;
}

raw_ostream &BlockFrequencyInfo::
printBlockFreq(raw_ostream &OS, const BlockFrequency Freq) const {
raw_ostream &BlockFrequencyInfo::printBlockFreq(raw_ostream &OS,
BlockFrequency Freq) const {
return BFI ? BFI->printBlockFreq(OS, Freq) : OS;
}

Expand All @@ -277,8 +278,8 @@ BlockFrequencyInfo::printBlockFreq(raw_ostream &OS,
return BFI ? BFI->printBlockFreq(OS, BB) : OS;
}

uint64_t BlockFrequencyInfo::getEntryFreq() const {
return BFI ? BFI->getEntryFreq() : 0;
BlockFrequency BlockFrequencyInfo::getEntryFreq() const {
return BFI ? BFI->getEntryFreq() : BlockFrequency(0);
}

void BlockFrequencyInfo::releaseMemory() { BFI.reset(); }
Expand Down
25 changes: 11 additions & 14 deletions llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,30 +581,27 @@ BlockFrequencyInfoImplBase::getBlockFreq(const BlockNode &Node) const {
report_fatal_error(OS.str());
}
#endif
return 0;
return BlockFrequency(0);
}
return Freqs[Node.Index].Integer;
return BlockFrequency(Freqs[Node.Index].Integer);
}

std::optional<uint64_t>
BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
const BlockNode &Node,
bool AllowSynthetic) const {
return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency(),
AllowSynthetic);
return getProfileCountFromFreq(F, getBlockFreq(Node), AllowSynthetic);
}

std::optional<uint64_t>
BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
uint64_t Freq,
bool AllowSynthetic) const {
std::optional<uint64_t> BlockFrequencyInfoImplBase::getProfileCountFromFreq(
const Function &F, BlockFrequency Freq, bool AllowSynthetic) const {
auto EntryCount = F.getEntryCount(AllowSynthetic);
if (!EntryCount)
return std::nullopt;
// Use 128 bit APInt to do the arithmetic to avoid overflow.
APInt BlockCount(128, EntryCount->getCount());
APInt BlockFreq(128, Freq);
APInt EntryFreq(128, getEntryFreq());
APInt BlockFreq(128, Freq.getFrequency());
APInt EntryFreq(128, getEntryFreq().getFrequency());
BlockCount *= BlockFreq;
// Rounded division of BlockCount by EntryFreq. Since EntryFreq is unsigned
// lshr by 1 gives EntryFreq/2.
Expand All @@ -627,10 +624,10 @@ BlockFrequencyInfoImplBase::getFloatingBlockFreq(const BlockNode &Node) const {
}

void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node,
uint64_t Freq) {
BlockFrequency Freq) {
assert(Node.isValid() && "Expected valid node");
assert(Node.Index < Freqs.size() && "Expected legal index");
Freqs[Node.Index].Integer = Freq;
Freqs[Node.Index].Integer = Freq.getFrequency();
}

std::string
Expand All @@ -651,9 +648,9 @@ BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS,

raw_ostream &
BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS,
const BlockFrequency &Freq) const {
BlockFrequency Freq) const {
Scaled64 Block(Freq.getFrequency(), 0);
Scaled64 Entry(getEntryFreq(), 0);
Scaled64 Entry(getEntryFreq().getFrequency(), 0);

return OS << Block / Entry;
}
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/Analysis/CFGPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,11 @@ bool DOTGraphTraits<DOTFuncInfo *>::isNodeHidden(const BasicBlock *Node,
const DOTFuncInfo *CFGInfo) {
if (HideColdPaths.getNumOccurrences() > 0)
if (auto *BFI = CFGInfo->getBFI()) {
uint64_t NodeFreq = BFI->getBlockFreq(Node).getFrequency();
uint64_t EntryFreq = BFI->getEntryFreq();
BlockFrequency NodeFreq = BFI->getBlockFreq(Node);
BlockFrequency EntryFreq = BFI->getEntryFreq();
// Hide blocks with relative frequency below HideColdPaths threshold.
if ((double)NodeFreq / EntryFreq < HideColdPaths)
if ((double)NodeFreq.getFrequency() / EntryFreq.getFrequency() <
HideColdPaths)
return true;
}
if (HideUnreachablePaths || HideDeoptimizePaths) {
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ static void computeFunctionSummary(
// information.
if (BFI != nullptr && Hotness == CalleeInfo::HotnessType::Unknown) {
uint64_t BBFreq = BFI->getBlockFreq(&BB).getFrequency();
uint64_t EntryFreq = BFI->getEntryFreq();
uint64_t EntryFreq = BFI->getEntryFreq().getFrequency();
ValueInfo.updateRelBlockFreq(BBFreq, EntryFreq);
}
} else {
Expand Down
7 changes: 3 additions & 4 deletions llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2591,9 +2591,8 @@ bool CodeGenPrepare::dupRetToEnableTailCallOpts(BasicBlock *BB,
(void)FoldReturnIntoUncondBranch(RetI, BB, TailCallBB);
assert(!VerifyBFIUpdates ||
BFI->getBlockFreq(BB) >= BFI->getBlockFreq(TailCallBB));
BFI->setBlockFreq(
BB,
(BFI->getBlockFreq(BB) - BFI->getBlockFreq(TailCallBB)).getFrequency());
BFI->setBlockFreq(BB,
(BFI->getBlockFreq(BB) - BFI->getBlockFreq(TailCallBB)));
ModifiedDT = ModifyDT::ModifyBBDT;
Changed = true;
++NumRetsDup;
Expand Down Expand Up @@ -7067,7 +7066,7 @@ bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
FreshBBs.insert(EndBlock);
}

BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock).getFrequency());
BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock));

static const unsigned MD[] = {
LLVMContext::MD_prof, LLVMContext::MD_unpredictable,
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,8 @@ RegBankSelect::MappingCost RegBankSelect::computeMapping(
return MappingCost::ImpossibleCost();

// If mapped with InstrMapping, MI will have the recorded cost.
MappingCost Cost(MBFI ? MBFI->getBlockFreq(MI.getParent()) : 1);
MappingCost Cost(MBFI ? MBFI->getBlockFreq(MI.getParent())
: BlockFrequency(1));
bool Saturated = Cost.addLocalCost(InstrMapping.getCost());
assert(!Saturated && "Possible mapping saturated the cost");
LLVM_DEBUG(dbgs() << "Evaluating mapping cost for: " << MI);
Expand Down Expand Up @@ -971,7 +972,7 @@ bool RegBankSelect::EdgeInsertPoint::canMaterialize() const {
return Src.canSplitCriticalEdge(DstOrSplit);
}

RegBankSelect::MappingCost::MappingCost(const BlockFrequency &LocalFreq)
RegBankSelect::MappingCost::MappingCost(BlockFrequency LocalFreq)
: LocalFreq(LocalFreq.getFrequency()) {}

bool RegBankSelect::MappingCost::addLocalCost(uint64_t Cost) {
Expand Down
Loading