Skip to content

Commit

Permalink
JIT: More CSE heuristics adjustments (#98257)
Browse files Browse the repository at this point in the history
Based on analysis of cases where the machine learning is struggling, add some more observations
and tweak some of the existing ones:
* where we use `log` for dynamic compression, bias results so they are always non-negative
* only consider integral vars for pressure estimate
* note if a CSE has a call
* note weighted tree costs
* note weighted local occurrences (approx pressure relief)
* note spread of occurrences (as fraction of BBs)
* note if CSE is something that can be contained (guess)
* note if CSE is cheap (cost 2 or 3) and is something that can be contained
* note if CSE might be "live across" a call in LSRA block ordering

The block spread and LSRA live-across features use RPO artifacts that may no longer be up to date.
It is not clear that this matters, as LSRA does not use RPO for block ordering.

Contributes to #92915.
  • Loading branch information
AndyAyersMS authored Feb 10, 2024
1 parent af53dab commit 78bd7de
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 19 deletions.
105 changes: 87 additions & 18 deletions src/coreclr/jit/optcse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2513,17 +2513,26 @@ void CSE_HeuristicRL::CaptureLocalWeights()
LclVarDsc* const varDsc = m_pCompiler->lvaGetDescByTrackedIndex(trackedIndex);

// Locals with no references aren't enregistered
//
if (varDsc->lvRefCnt() == 0)
{
continue;
}

// Some LclVars always have stack homes
//
if (varDsc->lvDoNotEnregister)
{
continue;
}

// Only consider for integral types
//
if (varTypeIsFloating(varDsc->TypeGet()) || varTypeIsMask(varDsc->TypeGet()))
{
continue;
}

JITDUMP("V%02u," FMT_WT "\n", m_pCompiler->lvaGetLclNum(varDsc), varDsc->lvRefCntWtd());
m_localWeights->push_back(varDsc->lvRefCntWtd() / BB_UNITY_WEIGHT);
}
Expand Down Expand Up @@ -2723,6 +2732,17 @@ void CSE_HeuristicRL::SoftmaxPolicy()
// 14. cse is marked GTF_MAKE_CSE (0/1)
// 15. cse num distinct locals
// 16. cse num local occurrences
// 17. cse has call (0/1)
// 18. log (cse use count weighted * costEx)
// 19. log (cse use count weighted * num local occurrences)
// 20. cse "distance" (max postorder num - min postorder num) / num BBs
// 21. cse is "containable" (0/1)
// 22. cse is cheap & containable (0/1)
// 23. is live across call in possible LSRA ordering (0/1)
//
// -----
//
// 24. log (pressure estimate weight)
//
void CSE_HeuristicRL::GetFeatures(CSEdsc* cse, double* features)
{
Expand All @@ -2737,20 +2757,13 @@ void CSE_HeuristicRL::GetFeatures(CSEdsc* cse, double* features)
return;
}

const unsigned char costEx = cse->csdTree->GetCostEx();
const unsigned char costEx = cse->csdTree->GetCostEx();
const double deMinimis = 1e-3;
const double deMinimusAdj = -log(deMinimis);

features[0] = costEx;

if (cse->csdUseWtCnt > 0)
{
features[1] = log(cse->csdUseWtCnt);
}

if (cse->csdDefWtCnt > 0)
{
features[2] = log(cse->csdDefWtCnt);
}

features[1] = deMinimusAdj + log(max(deMinimis, cse->csdUseWtCnt));
features[2] = deMinimusAdj + log(max(deMinimis, cse->csdDefWtCnt));
features[3] = cse->csdTree->GetCostSz();
features[4] = cse->csdUseCount;
features[5] = cse->csdDefCount;
Expand All @@ -2770,6 +2783,7 @@ void CSE_HeuristicRL::GetFeatures(CSEdsc* cse, double* features)
features[9] = booleanScale * isSharedConstant;

const bool isMinCost = (costEx == Compiler::MIN_CSE_COST);
const bool isLowCost = (costEx <= Compiler::MIN_CSE_COST + 1);

features[10] = booleanScale * isMinCost;

Expand All @@ -2780,18 +2794,71 @@ void CSE_HeuristicRL::GetFeatures(CSEdsc* cse, double* features)
features[13] = booleanScale * (isMinCost & isLiveAcrossCall);

// Is any CSE tree for this candidate marked GTF_MAKE_CSE (hoisting)
// Also gather data for "distance" metric.
//
bool isMakeCse = false;
const unsigned numBBs = m_pCompiler->fgBBcount;
bool isMakeCse = false;
unsigned minPostorderNum = numBBs;
unsigned maxPostorderNum = 0;
BasicBlock* minPostorderBlock = nullptr;
BasicBlock* maxPostorderBlock = nullptr;
for (treeStmtLst* treeList = cse->csdTreeList; treeList != nullptr && !isMakeCse; treeList = treeList->tslNext)
{
isMakeCse = ((treeList->tslTree->gtFlags & GTF_MAKE_CSE) != 0);
BasicBlock* const treeBlock = treeList->tslBlock;
unsigned postorderNum = treeBlock->bbPostorderNum;
if (postorderNum < minPostorderNum)
{
minPostorderNum = postorderNum;
minPostorderBlock = treeBlock;
}

if (postorderNum > maxPostorderNum)
{
maxPostorderNum = postorderNum;
maxPostorderBlock = treeBlock;
}

isMakeCse |= ((treeList->tslTree->gtFlags & GTF_MAKE_CSE) != 0);
}
const unsigned blockSpread = maxPostorderNum - minPostorderNum;

features[14] = booleanScale * isMakeCse;

// Locals data
//
features[15] = cse->numDistinctLocals;
features[16] = cse->numLocalOccurrences;

// More
//
features[17] = booleanScale * ((cse->csdTree->gtFlags & GTF_CALL) != 0);
features[18] = deMinimusAdj + log(max(deMinimis, cse->csdUseCount * cse->csdUseWtCnt));
features[19] = deMinimusAdj + log(max(deMinimis, cse->numLocalOccurrences * cse->csdUseWtCnt));
features[20] = booleanScale * ((double)(blockSpread) / numBBs);

const bool isContainable = cse->csdTree->OperIs(GT_ADD, GT_NOT, GT_MUL, GT_LSH);
features[21] = booleanScale * isContainable;
features[22] = booleanScale * (isContainable && isLowCost);

// LSRA "is live across call"
//
bool isLiveAcrossCallLSRA = isLiveAcrossCall;

if (!isLiveAcrossCallLSRA)
{
unsigned count = 0;
for (BasicBlock *block = minPostorderBlock;
block != nullptr && block != maxPostorderBlock && count < blockSpread; block = block->Next(), count++)
{
if (block->HasFlag(BBF_HAS_CALL))
{
isLiveAcrossCallLSRA = true;
break;
}
}
}

features[23] = booleanScale * isLiveAcrossCallLSRA;
}

//------------------------------------------------------------------------
Expand All @@ -2804,7 +2871,7 @@ void CSE_HeuristicRL::GetFeatures(CSEdsc* cse, double* features)
//
// Stopping features
//
// 17. int register pressure weight estimate (log)
// 24. int register pressure weight estimate (log)
//
// All boolean features are scaled up by booleanScale so their
// numeric range is similar to the non-boolean features
Expand All @@ -2819,8 +2886,9 @@ void CSE_HeuristicRL::GetStoppingFeatures(double* features)
// "remove" weight per local use occurrences * weightUses
// "add" weight of the CSE temp times * (weigh defs*2) + weightUses
//
double minWeight = 0.01;
double spillAtWeight = minWeight;
const double deMinimis = 1e-3;
double spillAtWeight = deMinimis;
const double deMinimusAdj = -log(deMinimis);

// Assume each already performed cse is occupying a register
//
Expand All @@ -2845,7 +2913,8 @@ void CSE_HeuristicRL::GetStoppingFeatures(double* features)
// Large frame...?
// todo: scan all vars, not just tracked?
//
features[17] = log(max(spillAtWeight, minWeight));

features[24] = deMinimusAdj + log(max(deMinimis, spillAtWeight));
}

//------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/optcse.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class CSE_HeuristicRL : public CSE_HeuristicCommon

enum
{
numParameters = 19,
numParameters = 25,
booleanScale = 5,
maxSteps = 65, // MAX_CSE_CNT + 1 (for stopping)
};
Expand Down

0 comments on commit 78bd7de

Please sign in to comment.