Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use BlockFrequency type in more places (NFC) #68266

Merged
merged 2 commits into from
Oct 5, 2023

Conversation

MatzeB
Copy link
Contributor

@MatzeB MatzeB commented Oct 4, 2023

The BlockFrequency class abstracts uint64_t frequency values. Use it more consistently in various APIs and disable implicit conversion to make usage more consistent and explicit.

  • Use BlockFrequency Freq parameter for setBlockFreq, getProfileCountFromFreq and setBlockFreqAndScale functions.
  • Return BlockFrequency in getEntryFreq() functions.
  • While on it change some const BlockFrequency& Freq parameters to plain BlockFreqency Freq.
  • Mark BlockFrequency(uint64_t) constructor as explicit.
  • Add missing BlockFrequency::operator!=.
  • Remove uint64_t BlockFreqency::getMaxFrequency().
  • Add BlockFrequency BlockFrequency::max() function.

@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2023

@llvm/pr-subscribers-function-specialization
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-regalloc
@llvm/pr-subscribers-llvm-support
@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-pgo

Changes

The BlockFrequency class abstracts uint64_t frequency values. Use it more consistently in various APIs and disable implicit conversion to make usage more consistent and explicit.

  • Use BlockFrequency Freq parameter for setBlockFreq, getProfileCountFromFreq and setBlockFreqAndScale functions.
  • Return BlockFrequency in getEntryFreq() functions.
  • While on it change some const BlockFrequency& Freq parameters to plain BlockFreqency Freq.
  • Mark BlockFrequency(uint64_t) constructor as explicit.
  • Add missing BlockFrequency::operator!=.
  • Remove uint64_t BlockFreqency::getMaxFrequency().
  • Add BlockFrequency BlockFrequency::max() function.

Patch is 55.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68266.diff

34 Files Affected:

  • (modified) llvm/include/llvm/Analysis/BlockFrequencyInfo.h (+4-4)
  • (modified) llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h (+10-10)
  • (modified) llvm/include/llvm/Analysis/ProfileSummaryInfo.h (+2-2)
  • (modified) llvm/include/llvm/CodeGen/MBFIWrapper.h (+1-1)
  • (modified) llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h (+5-4)
  • (modified) llvm/include/llvm/Support/BlockFrequency.h (+7-2)
  • (modified) llvm/include/llvm/Transforms/Instrumentation/CFGMST.h (+2-1)
  • (modified) llvm/lib/Analysis/BlockFrequencyInfo.cpp (+9-8)
  • (modified) llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp (+11-14)
  • (modified) llvm/lib/Analysis/CFGPrinter.cpp (+4-3)
  • (modified) llvm/lib/Analysis/ModuleSummaryAnalysis.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+3-4)
  • (modified) llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/MBFIWrapper.cpp (+2-4)
  • (modified) llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp (+5-5)
  • (modified) llvm/lib/CodeGen/MachineBlockPlacement.cpp (+73-72)
  • (modified) llvm/lib/CodeGen/RegAllocGreedy.cpp (+12-11)
  • (modified) llvm/lib/CodeGen/SelectOptimize.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/ShrinkWrap.cpp (+8-8)
  • (modified) llvm/lib/CodeGen/SpillPlacement.cpp (+12-8)
  • (modified) llvm/lib/CodeGen/SpillPlacement.h (+1-1)
  • (modified) llvm/lib/Target/PowerPC/PPCMIPeephole.cpp (+2-2)
  • (modified) llvm/lib/Transforms/IPO/FunctionSpecialization.cpp (+1-1)
  • (modified) llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Instrumentation/CGProfile.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+5-5)
  • (modified) llvm/lib/Transforms/Scalar/LoopSink.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp (+3-3)
  • (modified) llvm/lib/Transforms/Utils/CodeExtractor.cpp (+2-2)
  • (modified) llvm/lib/Transforms/Utils/InlineFunction.cpp (+5-6)
  • (modified) llvm/unittests/Analysis/BlockFrequencyInfoTest.cpp (+1-1)
  • (modified) llvm/unittests/Support/BlockFrequencyTest.cpp (+3-2)
  • (modified) llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (+6-3)
diff --git a/llvm/include/llvm/Analysis/BlockFrequencyInfo.h b/llvm/include/llvm/Analysis/BlockFrequencyInfo.h
index 39507570a1b2c3f..4f0692bad6eeca5 100644
--- a/llvm/include/llvm/Analysis/BlockFrequencyInfo.h
+++ b/llvm/include/llvm/Analysis/BlockFrequencyInfo.h
@@ -73,19 +73,19 @@ class BlockFrequencyInfo {
   /// Returns the estimated profile count of \p Freq.
   /// This uses the frequency \p Freq and multiplies it by
   /// the enclosing function's count (if available) and returns the value.
-  std::optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
+  std::optional<uint64_t> getProfileCountFromFreq(BlockFrequency Freq) const;
 
   /// Returns true if \p BB is an irreducible loop header
   /// block. Otherwise false.
   bool isIrrLoopHeader(const BasicBlock *BB);
 
   // Set the frequency of the given basic block.
-  void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
+  void setBlockFreq(const BasicBlock *BB, BlockFrequency Freq);
 
   /// Set the frequency of \p ReferenceBB to \p Freq and scale the frequencies
   /// of the blocks in \p BlocksToScale such that their frequencies relative
   /// to \p ReferenceBB remain unchanged.
-  void setBlockFreqAndScale(const BasicBlock *ReferenceBB, uint64_t Freq,
+  void setBlockFreqAndScale(const BasicBlock *ReferenceBB, BlockFrequency Freq,
                             SmallPtrSetImpl<BasicBlock *> &BlocksToScale);
 
   /// calculate - compute block frequency info for the given function.
@@ -100,7 +100,7 @@ class BlockFrequencyInfo {
   // BB and print it to OS.
   raw_ostream &printBlockFreq(raw_ostream &OS, const BasicBlock *BB) const;
 
-  uint64_t getEntryFreq() const;
+  BlockFrequency getEntryFreq() const;
   void releaseMemory();
   void print(raw_ostream &OS) const;
 
diff --git a/llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h b/llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
index 54d56f8472c2bcc..b9c0e2759227539 100644
--- a/llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
+++ b/llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
@@ -527,19 +527,18 @@ class BlockFrequencyInfoImplBase {
   getBlockProfileCount(const Function &F, const BlockNode &Node,
                        bool AllowSynthetic = false) const;
   std::optional<uint64_t>
-  getProfileCountFromFreq(const Function &F, uint64_t Freq,
+  getProfileCountFromFreq(const Function &F, BlockFrequency Freq,
                           bool AllowSynthetic = false) const;
   bool isIrrLoopHeader(const BlockNode &Node);
 
-  void setBlockFreq(const BlockNode &Node, uint64_t Freq);
+  void setBlockFreq(const BlockNode &Node, BlockFrequency Freq);
 
   raw_ostream &printBlockFreq(raw_ostream &OS, const BlockNode &Node) const;
-  raw_ostream &printBlockFreq(raw_ostream &OS,
-                              const BlockFrequency &Freq) const;
+  raw_ostream &printBlockFreq(raw_ostream &OS, BlockFrequency Freq) const;
 
-  uint64_t getEntryFreq() const {
+  BlockFrequency getEntryFreq() const {
     assert(!Freqs.empty());
-    return Freqs[0].Integer;
+    return BlockFrequency(Freqs[0].Integer);
   }
 };
 
@@ -1029,7 +1028,7 @@ template <class BT> class BlockFrequencyInfoImpl : BlockFrequencyInfoImplBase {
   }
 
   std::optional<uint64_t>
-  getProfileCountFromFreq(const Function &F, uint64_t Freq,
+  getProfileCountFromFreq(const Function &F, BlockFrequency Freq,
                           bool AllowSynthetic = false) const {
     return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq,
                                                                AllowSynthetic);
@@ -1039,7 +1038,7 @@ template <class BT> class BlockFrequencyInfoImpl : BlockFrequencyInfoImplBase {
     return BlockFrequencyInfoImplBase::isIrrLoopHeader(getNode(BB));
   }
 
-  void setBlockFreq(const BlockT *BB, uint64_t Freq);
+  void setBlockFreq(const BlockT *BB, BlockFrequency Freq);
 
   void forgetBlock(const BlockT *BB) {
     // We don't erase corresponding items from `Freqs`, `RPOT` and other to
@@ -1145,12 +1144,13 @@ void BlockFrequencyInfoImpl<BT>::calculate(const FunctionT &F,
     // blocks and unknown blocks.
     for (const BlockT &BB : F)
       if (!Nodes.count(&BB))
-        setBlockFreq(&BB, 0);
+        setBlockFreq(&BB, BlockFrequency());
   }
 }
 
 template <class BT>
-void BlockFrequencyInfoImpl<BT>::setBlockFreq(const BlockT *BB, uint64_t Freq) {
+void BlockFrequencyInfoImpl<BT>::setBlockFreq(const BlockT *BB,
+                                              BlockFrequency Freq) {
   if (Nodes.count(BB))
     BlockFrequencyInfoImplBase::setBlockFreq(getNode(BB), Freq);
   else {
diff --git a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
index 38eb71ba271d06b..e49538bfaf80fb1 100644
--- a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
+++ b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
@@ -209,7 +209,7 @@ class ProfileSummaryInfo {
 
   template <typename BFIT>
   bool isColdBlock(BlockFrequency BlockFreq, const BFIT *BFI) const {
-    auto Count = BFI->getProfileCountFromFreq(BlockFreq.getFrequency());
+    auto Count = BFI->getProfileCountFromFreq(BlockFreq);
     return Count && isColdCount(*Count);
   }
 
@@ -315,7 +315,7 @@ class ProfileSummaryInfo {
   bool isHotOrColdBlockNthPercentile(int PercentileCutoff,
                                      BlockFrequency BlockFreq,
                                      BFIT *BFI) const {
-    auto Count = BFI->getProfileCountFromFreq(BlockFreq.getFrequency());
+    auto Count = BFI->getProfileCountFromFreq(BlockFreq);
     if (isHot)
       return Count && isHotCountNthPercentile(PercentileCutoff, *Count);
     else
diff --git a/llvm/include/llvm/CodeGen/MBFIWrapper.h b/llvm/include/llvm/CodeGen/MBFIWrapper.h
index 714ecc5d4334e40..eca2aadb43d96cc 100644
--- a/llvm/include/llvm/CodeGen/MBFIWrapper.h
+++ b/llvm/include/llvm/CodeGen/MBFIWrapper.h
@@ -38,7 +38,7 @@ class MBFIWrapper {
   raw_ostream &printBlockFreq(raw_ostream &OS,
                               const BlockFrequency Freq) const;
   void view(const Twine &Name, bool isSimple = true);
-  uint64_t getEntryFreq() const;
+  BlockFrequency getEntryFreq() const;
   const MachineBlockFrequencyInfo &getMBFI() { return MBFI; }
 
  private:
diff --git a/llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h b/llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h
index 1152eefed6e45c3..2c15ed1732d9b47 100644
--- a/llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h
@@ -66,14 +66,15 @@ class MachineBlockFrequencyInfo : public MachineFunctionPass {
   /// Compute the frequency of the block, relative to the entry block.
   /// This API assumes getEntryFreq() is non-zero.
   double getBlockFreqRelativeToEntryBlock(const MachineBasicBlock *MBB) const {
-    assert(getEntryFreq() != 0 && "getEntryFreq() should not return 0 here!");
+    assert(getEntryFreq() != BlockFrequency() &&
+           "getEntryFreq() should not return 0 here!");
     return static_cast<double>(getBlockFreq(MBB).getFrequency()) /
-           static_cast<double>(getEntryFreq());
+           static_cast<double>(getEntryFreq().getFrequency());
   }
 
   std::optional<uint64_t>
   getBlockProfileCount(const MachineBasicBlock *MBB) const;
-  std::optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
+  std::optional<uint64_t> getProfileCountFromFreq(BlockFrequency Freq) const;
 
   bool isIrrLoopHeader(const MachineBasicBlock *MBB) const;
 
@@ -101,7 +102,7 @@ class MachineBlockFrequencyInfo : public MachineFunctionPass {
 
   /// Divide a block's BlockFrequency::getFrequency() value by this value to
   /// obtain the entry block - relative frequency of said block.
-  uint64_t getEntryFreq() const;
+  BlockFrequency getEntryFreq() const;
 };
 
 } // end namespace llvm
diff --git a/llvm/include/llvm/Support/BlockFrequency.h b/llvm/include/llvm/Support/BlockFrequency.h
index 12a753301b36aba..8b172ee486aab85 100644
--- a/llvm/include/llvm/Support/BlockFrequency.h
+++ b/llvm/include/llvm/Support/BlockFrequency.h
@@ -26,10 +26,11 @@ class BlockFrequency {
   uint64_t Frequency;
 
 public:
-  BlockFrequency(uint64_t Freq = 0) : Frequency(Freq) { }
+  BlockFrequency() : Frequency(0) {}
+  explicit BlockFrequency(uint64_t Freq) : Frequency(Freq) {}
 
   /// Returns the maximum possible frequency, the saturation value.
-  static uint64_t getMaxFrequency() { return UINT64_MAX; }
+  static BlockFrequency max() { return BlockFrequency(UINT64_MAX); }
 
   /// Returns the frequency as a fixpoint number scaled by the entry
   /// frequency.
@@ -112,6 +113,10 @@ class BlockFrequency {
   bool operator==(BlockFrequency RHS) const {
     return Frequency == RHS.Frequency;
   }
+
+  bool operator!=(BlockFrequency RHS) const {
+    return Frequency != RHS.Frequency;
+  }
 };
 
 } // namespace llvm
diff --git a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
index 4d31898bb3147b6..269441db7a55896 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
@@ -101,7 +101,8 @@ template <class Edge, class BBInfo> class CFGMST {
     LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
 
     BasicBlock *Entry = &(F.getEntryBlock());
-    uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
+    uint64_t EntryWeight =
+        (BFI != nullptr ? BFI->getEntryFreq().getFrequency() : 2);
     // If we want to instrument the entry count, lower the weight to 0.
     if (InstrumentFuncEntry)
       EntryWeight = 0;
diff --git a/llvm/lib/Analysis/BlockFrequencyInfo.cpp b/llvm/lib/Analysis/BlockFrequencyInfo.cpp
index b18d04cc73dbca0..c6cb5470463d353 100644
--- a/llvm/lib/Analysis/BlockFrequencyInfo.cpp
+++ b/llvm/lib/Analysis/BlockFrequencyInfo.cpp
@@ -201,7 +201,7 @@ void BlockFrequencyInfo::calculate(const Function &F,
 }
 
 BlockFrequency BlockFrequencyInfo::getBlockFreq(const BasicBlock *BB) const {
-  return BFI ? BFI->getBlockFreq(BB) : 0;
+  return BFI ? BFI->getBlockFreq(BB) : BlockFrequency(0);
 }
 
 std::optional<uint64_t>
@@ -214,7 +214,7 @@ BlockFrequencyInfo::getBlockProfileCount(const BasicBlock *BB,
 }
 
 std::optional<uint64_t>
-BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+BlockFrequencyInfo::getProfileCountFromFreq(BlockFrequency Freq) const {
   if (!BFI)
     return std::nullopt;
   return BFI->getProfileCountFromFreq(*getFunction(), Freq);
@@ -225,17 +225,18 @@ bool BlockFrequencyInfo::isIrrLoopHeader(const BasicBlock *BB) {
   return BFI->isIrrLoopHeader(BB);
 }
 
-void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
+void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB,
+                                      BlockFrequency Freq) {
   assert(BFI && "Expected analysis to be available");
   BFI->setBlockFreq(BB, Freq);
 }
 
 void BlockFrequencyInfo::setBlockFreqAndScale(
-    const BasicBlock *ReferenceBB, uint64_t Freq,
+    const BasicBlock *ReferenceBB, BlockFrequency Freq,
     SmallPtrSetImpl<BasicBlock *> &BlocksToScale) {
   assert(BFI && "Expected analysis to be available");
   // Use 128 bits APInt to avoid overflow.
-  APInt NewFreq(128, Freq);
+  APInt NewFreq(128, Freq.getFrequency());
   APInt OldFreq(128, BFI->getBlockFreq(ReferenceBB).getFrequency());
   APInt BBFreq(128, 0);
   for (auto *BB : BlocksToScale) {
@@ -247,7 +248,7 @@ void BlockFrequencyInfo::setBlockFreqAndScale(
     // a hot spot, one of the options proposed in
     // https://reviews.llvm.org/D28535#650071 could be used to avoid this.
     BBFreq = BBFreq.udiv(OldFreq);
-    BFI->setBlockFreq(BB, BBFreq.getLimitedValue());
+    BFI->setBlockFreq(BB, BlockFrequency(BBFreq.getLimitedValue()));
   }
   BFI->setBlockFreq(ReferenceBB, Freq);
 }
@@ -277,8 +278,8 @@ BlockFrequencyInfo::printBlockFreq(raw_ostream &OS,
   return BFI ? BFI->printBlockFreq(OS, BB) : OS;
 }
 
-uint64_t BlockFrequencyInfo::getEntryFreq() const {
-  return BFI ? BFI->getEntryFreq() : 0;
+BlockFrequency BlockFrequencyInfo::getEntryFreq() const {
+  return BFI ? BFI->getEntryFreq() : BlockFrequency(0);
 }
 
 void BlockFrequencyInfo::releaseMemory() { BFI.reset(); }
diff --git a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
index 82b1e3b9eede709..583a038cc74e75d 100644
--- a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
+++ b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
@@ -581,30 +581,27 @@ BlockFrequencyInfoImplBase::getBlockFreq(const BlockNode &Node) const {
       report_fatal_error(OS.str());
     }
 #endif
-    return 0;
+    return BlockFrequency();
   }
-  return Freqs[Node.Index].Integer;
+  return BlockFrequency(Freqs[Node.Index].Integer);
 }
 
 std::optional<uint64_t>
 BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
                                                  const BlockNode &Node,
                                                  bool AllowSynthetic) const {
-  return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency(),
-                                 AllowSynthetic);
+  return getProfileCountFromFreq(F, getBlockFreq(Node), AllowSynthetic);
 }
 
-std::optional<uint64_t>
-BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
-                                                    uint64_t Freq,
-                                                    bool AllowSynthetic) const {
+std::optional<uint64_t> BlockFrequencyInfoImplBase::getProfileCountFromFreq(
+    const Function &F, BlockFrequency Freq, bool AllowSynthetic) const {
   auto EntryCount = F.getEntryCount(AllowSynthetic);
   if (!EntryCount)
     return std::nullopt;
   // Use 128 bit APInt to do the arithmetic to avoid overflow.
   APInt BlockCount(128, EntryCount->getCount());
-  APInt BlockFreq(128, Freq);
-  APInt EntryFreq(128, getEntryFreq());
+  APInt BlockFreq(128, Freq.getFrequency());
+  APInt EntryFreq(128, getEntryFreq().getFrequency());
   BlockCount *= BlockFreq;
   // Rounded division of BlockCount by EntryFreq. Since EntryFreq is unsigned
   // lshr by 1 gives EntryFreq/2.
@@ -627,10 +624,10 @@ BlockFrequencyInfoImplBase::getFloatingBlockFreq(const BlockNode &Node) const {
 }
 
 void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node,
-                                              uint64_t Freq) {
+                                              BlockFrequency Freq) {
   assert(Node.isValid() && "Expected valid node");
   assert(Node.Index < Freqs.size() && "Expected legal index");
-  Freqs[Node.Index].Integer = Freq;
+  Freqs[Node.Index].Integer = Freq.getFrequency();
 }
 
 std::string
@@ -651,9 +648,9 @@ BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS,
 
 raw_ostream &
 BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS,
-                                           const BlockFrequency &Freq) const {
+                                           BlockFrequency Freq) const {
   Scaled64 Block(Freq.getFrequency(), 0);
-  Scaled64 Entry(getEntryFreq(), 0);
+  Scaled64 Entry(getEntryFreq().getFrequency(), 0);
 
   return OS << Block / Entry;
 }
diff --git a/llvm/lib/Analysis/CFGPrinter.cpp b/llvm/lib/Analysis/CFGPrinter.cpp
index f05dd6852d6dc93..9f55371f259b2cf 100644
--- a/llvm/lib/Analysis/CFGPrinter.cpp
+++ b/llvm/lib/Analysis/CFGPrinter.cpp
@@ -318,10 +318,11 @@ bool DOTGraphTraits<DOTFuncInfo *>::isNodeHidden(const BasicBlock *Node,
                                                  const DOTFuncInfo *CFGInfo) {
   if (HideColdPaths.getNumOccurrences() > 0)
     if (auto *BFI = CFGInfo->getBFI()) {
-      uint64_t NodeFreq = BFI->getBlockFreq(Node).getFrequency();
-      uint64_t EntryFreq = BFI->getEntryFreq();
+      BlockFrequency NodeFreq = BFI->getBlockFreq(Node);
+      BlockFrequency EntryFreq = BFI->getEntryFreq();
       // Hide blocks with relative frequency below HideColdPaths threshold.
-      if ((double)NodeFreq / EntryFreq < HideColdPaths)
+      if ((double)NodeFreq.getFrequency() / EntryFreq.getFrequency() <
+          HideColdPaths)
         return true;
     }
   if (HideUnreachablePaths || HideDeoptimizePaths) {
diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
index a88622efa12db8c..058a107691674ce 100644
--- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -413,7 +413,7 @@ static void computeFunctionSummary(
         // information.
         if (BFI != nullptr && Hotness == CalleeInfo::HotnessType::Unknown) {
           uint64_t BBFreq = BFI->getBlockFreq(&BB).getFrequency();
-          uint64_t EntryFreq = BFI->getEntryFreq();
+          uint64_t EntryFreq = BFI->getEntryFreq().getFrequency();
           ValueInfo.updateRelBlockFreq(BBFreq, EntryFreq);
         }
       } else {
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index e31b08df7dbbe80..371f6598e6b2b35 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -2591,9 +2591,8 @@ bool CodeGenPrepare::dupRetToEnableTailCallOpts(BasicBlock *BB,
     (void)FoldReturnIntoUncondBranch(RetI, BB, TailCallBB);
     assert(!VerifyBFIUpdates ||
            BFI->getBlockFreq(BB) >= BFI->getBlockFreq(TailCallBB));
-    BFI->setBlockFreq(
-        BB,
-        (BFI->getBlockFreq(BB) - BFI->getBlockFreq(TailCallBB)).getFrequency());
+    BFI->setBlockFreq(BB,
+                      (BFI->getBlockFreq(BB) - BFI->getBlockFreq(TailCallBB)));
     ModifiedDT = ModifyDT::ModifyBBDT;
     Changed = true;
     ++NumRetsDup;
@@ -7067,7 +7066,7 @@ bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
     FreshBBs.insert(EndBlock);
   }
 
-  BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock).getFrequency());
+  BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock));
 
   static const unsigned MD[] = {
       LLVMContext::MD_prof, LLVMContext::MD_unpredictable,
diff --git a/llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp b/llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
index d201342cd61dbc8..d19d5ab60305ed8 100644
--- a/llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
@@ -449,7 +449,8 @@ RegBankSelect::MappingCost RegBankSelect::computeMapping(
     return MappingCost::ImpossibleCost();
 
   // If mapped with InstrMapping, MI will have the recorded cost.
-  MappingCost Cost(MBFI ? MBFI->getBlockFreq(MI.getParent()) : 1);
+  MappingCost Cost(MBFI ? MBFI->getBlockFreq(MI.getParent())
+                        : BlockFrequency(1));
   bool Saturated = Cost.addLocalCost(InstrMapping.getCost());
   assert(!Saturated && "Possible mapping saturated the cost");
   LLVM_DEBUG(dbgs() << "Evaluating mapping cost for: " << MI);
diff --git a/llvm/lib/CodeGen/MBFIWrapper.cpp b/llvm/lib/CodeGen/MBFIWrapper.cpp
index 5b388be27839464..351e38ebde5cece 100644
--- a/llvm/lib/CodeGen/MBFIWrapper.cpp
+++ b/llvm/lib/CodeGen/MBFIWrapper.cpp
@@ -38,7 +38,7 @@ MBFIWrapper::getBlockProfileCount(const MachineBasicBlock *MBB) const {
   // Modified block frequency also impacts profile count. So we should compute
   // profile count from new block frequency if it has been changed.
   if (I != MergedBBFreq.end())
-    return MBFI.getProfileCountFromFreq(I->second.getFrequency());
+    return MBFI.getProfileCountFromFreq(I->second);
 
   return MBFI.getBlockProfileCount(MBB);
 }
@@ -57,6 +57,4 @@ void MBFIWrapper::view(const Twine &Name, bool isSimple) {
   MBFI.view(Name, isSimple);
 }
 
-uint64_t MBFIWrapper::getEntryFreq() const {
-  return MBFI.getEntryFreq();
-}
+BlockFrequency MBFIWrapper::getEntryFreq() const { return MBFI.getEntryFreq(); }
diff --git a/llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp b/llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp
index b1cbe525d7e6c16..7d3a2c0e34bffe5 100644
--- a/llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp
+++ b/llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp
@@ -228,7 +228,7 @@ void MachineBlockFrequencyInfo::view(const Twine &Name, bool isSimple) const {
 
 BlockFrequency
 MachineBlockFrequencyInfo::getBlockFreq(const MachineBasicBlock *MBB) const {
-  return MBFI ? MBFI->getBlockFreq(MBB) : 0;
+  return MBFI ? MBFI->getBlockFreq(MBB) : BlockFrequency(0);
 }
 
 std::optional<uint64_t> MachineBlockFrequencyInfo::getBlockProfileCount(
@@ -241,7 +241,7 @@ std::optional<uint64_t> MachineBlockFrequencyInfo::getBlockProfi...
[truncated]

@MatzeB MatzeB force-pushed the use_more_blockfrequency branch from 5a7c8e1 to f116df2 Compare October 4, 2023 22:22
@github-actions
Copy link

github-actions bot commented Oct 4, 2023

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff cb7cf626d26e50887828d466c0187907719824d4 09bf86aad9e656ceee139a568e45b7721e04dbcc -- llvm/include/llvm/Analysis/BlockFrequencyInfo.h llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h llvm/include/llvm/Analysis/ProfileSummaryInfo.h llvm/include/llvm/CodeGen/GlobalISel/RegBankSelect.h llvm/include/llvm/CodeGen/MBFIWrapper.h llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h llvm/include/llvm/Support/BlockFrequency.h llvm/include/llvm/Transforms/Instrumentation/CFGMST.h llvm/lib/Analysis/BlockFrequencyInfo.cpp llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp llvm/lib/Analysis/CFGPrinter.cpp llvm/lib/Analysis/ModuleSummaryAnalysis.cpp llvm/lib/CodeGen/CodeGenPrepare.cpp llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp llvm/lib/CodeGen/MBFIWrapper.cpp llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp llvm/lib/CodeGen/MachineBlockPlacement.cpp llvm/lib/CodeGen/RegAllocGreedy.cpp llvm/lib/CodeGen/SelectOptimize.cpp llvm/lib/CodeGen/ShrinkWrap.cpp llvm/lib/CodeGen/SpillPlacement.cpp llvm/lib/CodeGen/SpillPlacement.h llvm/lib/Target/PowerPC/PPCMIPeephole.cpp llvm/lib/Transforms/IPO/FunctionSpecialization.cpp llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp llvm/lib/Transforms/Instrumentation/CGProfile.cpp llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp llvm/lib/Transforms/Scalar/JumpThreading.cpp llvm/lib/Transforms/Scalar/LoopSink.cpp llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp llvm/lib/Transforms/Utils/CodeExtractor.cpp llvm/lib/Transforms/Utils/InlineFunction.cpp llvm/unittests/Analysis/BlockFrequencyInfoTest.cpp llvm/unittests/Support/BlockFrequencyTest.cpp llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/MBFIWrapper.cpp b/llvm/lib/CodeGen/MBFIWrapper.cpp
index 4f8b921b0318..2a1937756328 100644
--- a/llvm/lib/CodeGen/MBFIWrapper.cpp
+++ b/llvm/lib/CodeGen/MBFIWrapper.cpp
@@ -48,7 +48,7 @@ raw_ostream & MBFIWrapper::printBlockFreq(raw_ostream &OS,
   return MBFI.printBlockFreq(OS, getBlockFreq(MBB));
 }
 
-raw_ostream & MBFIWrapper::printBlockFreq(raw_ostream &OS,
+raw_ostream &MBFIWrapper::printBlockFreq(raw_ostream &OS,
                                          BlockFrequency Freq) const {
   return MBFI.printBlockFreq(OS, Freq);
 }

The `BlockFrequency` class abstracts `uint64_t` frequency values. Use it
more consistently in various APIs and disable implicit conversion to
make usage more consistent and explicit.

- Use `BlockFrequency Freq` parameter for `setBlockFreq`,
  `getProfileCountFromFreq` and `setBlockFreqAndScale` functions.
- Return `BlockFrequency` in `getEntryFreq()` functions.
- While on it change some `const BlockFrequency& Freq` parameters to
  plain `BlockFreqency Freq`.
- Mark `BlockFrequency(uint64_t)` constructor as explicit.
- Add missing `BlockFrequency::operator!=`.
- Remove `uint64_t BlockFreqency::getMaxFrequency()`.
- Add `BlockFrequency BlockFrequency::max()` function.
@MatzeB MatzeB force-pushed the use_more_blockfrequency branch from f116df2 to 9207db5 Compare October 4, 2023 22:31
@david-xl
Copy link
Contributor

david-xl commented Oct 5, 2023

Making the interfaces to use BlockFrequency consistently is good, but having programming convenience to directly assign from or compare with uint64_t is also handy. Perhaps add member functions?

@MatzeB
Copy link
Contributor Author

MatzeB commented Oct 5, 2023

There is already BlockFrequency::getFrequency() to access the uint64_t value. Generally I do not want to make things too convenient (so no implicit conversion to or from the type). Currently everytime code computes with the uint64_t directly there is a high chance of overflows not being handled (there is still a lot of bugs in that area), while for the BlockFrequency abstraction we get safer saturating behavior (and I have some personal patches that produce warnings on overflow within BlockFrequency that may be worth upstreaming later).

I also think it would be worth an experiment to use llvm::Scaled64 or double instead of uint64_t which is a lot easier to pull off when the abstraction is used in more places.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants