Skip to content

Commit

Permalink
[FMV][GlobalOpt] Statically resolve calls to versioned functions. (ll…
Browse files Browse the repository at this point in the history
…vm#87939)

To deduce whether the optimization is legal we need to compare the target
features between caller and callee versions. The criteria for bypassing
the resolver are the following:

 * If the callee's feature set is a subset of the caller's feature set,
   then the callee is a candidate for direct call.

 * Among such candidates the one of highest priority is the best match
   and it shall be picked, unless there is a version of the callee with
   higher priority than the best match which cannot be picked from a
   higher priority caller (directly or through the resolver).

 * For every higher priority callee version than the best match, there
   is a higher priority caller version whose feature set availability
   is implied by the callee's feature set.

Example:

Callers and Callees are ordered in decreasing priority.
The arrows indicate successful call redirections.

  Caller        Callee      Explanation
=========================================================================
mops+sve2 --+--> mops       all the callee versions are subsets of the
            |               caller but mops has the highest priority
            |
     mops --+    sve2       between mops and default callees, mops wins

      sve        sve        between sve and default callees, sve wins
                            but sve2 does not have a high priority caller

  default -----> default    sve (callee) implies sve (caller),
                            sve2(callee) implies sve (caller),
                            mops(callee) implies mops(caller)
  • Loading branch information
labrinea authored and DKLoehr committed Jan 17, 2025
1 parent 7822a34 commit 59ea6fc
Show file tree
Hide file tree
Showing 9 changed files with 608 additions and 10 deletions.
17 changes: 17 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,13 @@ class TargetTransformInfo {
/// false, but it shouldn't matter what it returns anyway.
bool hasArmWideBranch(bool Thumb) const;

/// Returns a bitmask constructed from the target-features or fmv-features
/// metadata of a function.
uint64_t getFeatureMask(const Function &F) const;

/// Returns true if this is an instance of a function with multiple versions.
bool isMultiversionedFunction(const Function &F) const;

/// \return The maximum number of function arguments the target supports.
unsigned getMaxNumArgs() const;

Expand Down Expand Up @@ -2312,6 +2319,8 @@ class TargetTransformInfo::Concept {
virtual VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual bool hasArmWideBranch(bool Thumb) const = 0;
virtual uint64_t getFeatureMask(const Function &F) const = 0;
virtual bool isMultiversionedFunction(const Function &F) const = 0;
virtual unsigned getMaxNumArgs() const = 0;
virtual unsigned getNumBytesToPadGlobalArray(unsigned Size,
Type *ArrayType) const = 0;
Expand Down Expand Up @@ -3144,6 +3153,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.hasArmWideBranch(Thumb);
}

// Model override of Concept::getFeatureMask: hands \p F to the wrapped
// implementation object and returns its feature bitmask unchanged.
uint64_t getFeatureMask(const Function &F) const override {
return Impl.getFeatureMask(F);
}

// Model override of Concept::isMultiversionedFunction: hands \p F to the
// wrapped implementation object and returns its answer unchanged.
bool isMultiversionedFunction(const Function &F) const override {
return Impl.isMultiversionedFunction(F);
}

unsigned getMaxNumArgs() const override {
return Impl.getMaxNumArgs();
}
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,10 @@ class TargetTransformInfoImplBase {

bool hasArmWideBranch(bool) const { return false; }

// Base-class default: reports an empty feature mask (0) for every function.
uint64_t getFeatureMask(const Function &F) const { return 0; }

// Base-class default: no function is reported as multiversioned.
bool isMultiversionedFunction(const Function &F) const { return false; }

unsigned getMaxNumArgs() const { return UINT_MAX; }

unsigned getNumBytesToPadGlobalArray(unsigned Size, Type *ArrayType) const {
Expand Down
13 changes: 8 additions & 5 deletions llvm/include/llvm/TargetParser/AArch64TargetParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,16 @@ void fillValidCPUArchList(SmallVectorImpl<StringRef> &Values);

bool isX18ReservedByDefault(const Triple &TT);

// Return the priority for a given set of FMV features.
// For a given set of feature names, which can be either target-features, or
// fmv-features metadata, expand their dependencies and then return a bitmask
// corresponding to the entries of AArch64::FeatPriorities.
uint64_t getFMVPriority(ArrayRef<StringRef> Features);

// For given feature names, return a bitmask corresponding to the entries of
// AArch64::CPUFeatures. The values in CPUFeatures are not bitmasks themselves,
// they are sequential (0, 1, 2, 3, ...). The resulting bitmask is used at
// runtime to test whether a certain FMV feature is available on the host.
// For a given set of FMV feature names, expand their dependencies and then
// return a bitmask corresponding to the entries of AArch64::CPUFeatures.
// The values in CPUFeatures are not bitmasks themselves, they are sequential
// (0, 1, 2, 3, ...). The resulting bitmask is used at runtime to test whether
// a certain FMV feature is available on the host.
uint64_t getCpuSupportsMask(ArrayRef<StringRef> Features);

void PrintSupportedExtensions();
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,14 @@ bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const {
return TTIImpl->hasArmWideBranch(Thumb);
}

// Public entry point: forwards to the type-erased implementation (TTIImpl)
// and returns the feature bitmask it computes for \p F.
uint64_t TargetTransformInfo::getFeatureMask(const Function &F) const {
return TTIImpl->getFeatureMask(F);
}

// Public entry point: forwards to the type-erased implementation (TTIImpl)
// and returns whether it considers \p F a multiversioned function.
bool TargetTransformInfo::isMultiversionedFunction(const Function &F) const {
return TTIImpl->isMultiversionedFunction(F);
}

unsigned TargetTransformInfo::getMaxNumArgs() const {
return TTIImpl->getMaxNumArgs();
}
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/TargetParser/AArch64TargetParser.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
#include <algorithm>
Expand Down Expand Up @@ -248,6 +249,19 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
return false;
}

uint64_t AArch64TTIImpl::getFeatureMask(const Function &F) const {
StringRef AttributeStr =
isMultiversionedFunction(F) ? "fmv-features" : "target-features";
StringRef FeatureStr = F.getFnAttribute(AttributeStr).getValueAsString();
SmallVector<StringRef, 8> Features;
FeatureStr.split(Features, ",");
return AArch64::getFMVPriority(Features);
}

// A function is treated as multiversioned exactly when it carries the
// "fmv-features" function attribute.
bool AArch64TTIImpl::isMultiversionedFunction(const Function &F) const {
return F.hasFnAttribute("fmv-features");
}

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const;

uint64_t getFeatureMask(const Function &F) const;

bool isMultiversionedFunction(const Function &F) const;

/// \name Scalar TTI Implementations
/// @{

Expand Down
31 changes: 26 additions & 5 deletions llvm/lib/TargetParser/AArch64TargetParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,33 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA
return {};
}

// Scans the FMV info table for the first entry whose (optional) extension ID
// equals \p ExtID. Returns a copy of that entry, or an empty optional when no
// entry matches.
std::optional<AArch64::FMVInfo> lookupFMVByID(AArch64::ArchExtKind ExtID) {
  for (const AArch64::FMVInfo &Entry : AArch64::getFMVInfo()) {
    // Entries without an associated extension ID can never match.
    if (!Entry.ID)
      continue;
    if (*Entry.ID == ExtID)
      return Entry;
  }
  return std::nullopt;
}

// Computes a priority bitmask for the given feature names.
//
// NOTE(review): this block is rendered from a diff without +/- markers. The
// first five body lines (through `return Priority;`) appear to be the removed
// pre-change implementation, matching the "5 deletions" reported for this
// file; the statements after them appear to be the replacement body. Confirm
// against the actual source file before relying on this listing.
uint64_t AArch64::getFMVPriority(ArrayRef<StringRef> Features) {
uint64_t Priority = 0;
for (StringRef Feature : Features)
if (std::optional<FMVInfo> Info = parseFMVExtension(Feature))
Priority |= (1ULL << Info->PriorityBit);
return Priority;
// Transitively enable the Arch Extensions which correspond to each feature.
ExtensionSet FeatureBits;
for (const StringRef Feature : Features) {
// First try to interpret the name as an FMV extension.
std::optional<FMVInfo> FMV = parseFMVExtension(Feature);
if (!FMV) {
// Otherwise try to interpret it as a target feature and map the resulting
// extension ID back to an FMV table entry.
if (std::optional<ExtensionInfo> Info = targetFeatureToExtension(Feature))
FMV = lookupFMVByID(Info->ID);
}
// Names that resolve to no FMV entry, or to an entry without an extension
// ID, contribute nothing.
if (FMV && FMV->ID)
FeatureBits.enable(*FMV->ID);
}

// Construct a bitmask for all the transitively enabled Arch Extensions.
uint64_t PriorityMask = 0;
for (const FMVInfo &Info : getFMVInfo())
if (Info.ID && FeatureBits.Enabled.test(*Info.ID))
PriorityMask |= (1ULL << Info.PriorityBit);

return PriorityMask;
}

uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> Features) {
Expand Down
162 changes: 162 additions & 0 deletions llvm/lib/Transforms/IPO/GlobalOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2641,6 +2641,165 @@ DeleteDeadIFuncs(Module &M,
return Changed;
}

// Walks the use-def chain of \p V backwards through select and phi nodes.
// Every Function reached this way is appended to \p Versions, provided the
// target reports it as multiversioned. Returns true when the whole chain was
// traversed successfully; returns false as soon as a non-multiversioned
// function or any other kind of value is encountered (entries appended before
// the failure are left in \p Versions).
static bool collectVersions(TargetTransformInfo &TTI, Value *V,
                            SmallVectorImpl<Function *> &Versions) {
  if (auto *Fn = dyn_cast<Function>(V)) {
    if (!TTI.isMultiversionedFunction(*Fn))
      return false;
    Versions.push_back(Fn);
    return true;
  }

  // A select contributes both of its arms; the true arm is visited first and
  // the false arm only if the true arm succeeded.
  if (auto *Select = dyn_cast<SelectInst>(V))
    return collectVersions(TTI, Select->getTrueValue(), Versions) &&
           collectVersions(TTI, Select->getFalseValue(), Versions);

  // A phi contributes every incoming value, in operand order.
  if (auto *PN = dyn_cast<PHINode>(V)) {
    for (Value *Incoming : PN->incoming_values())
      if (!collectVersions(TTI, Incoming, Versions))
        return false;
    return true;
  }

  // Unknown kind of value. Bail.
  return false;
}

// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
// deduce whether the optimization is legal we need to compare the target
// features between caller and callee versions. The criteria for bypassing
// the resolver are the following:
//
// * If the callee's feature set is a subset of the caller's feature set,
// then the callee is a candidate for direct call.
//
// * Among such candidates the one of highest priority is the best match
// and it shall be picked, unless there is a version of the callee with
// higher priority than the best match which cannot be picked from a
// higher priority caller (directly or through the resolver).
//
// * For every higher priority callee version than the best match, there
// is a higher priority caller version whose feature set availability
// is implied by the callee's feature set.
//
// \p M is the module whose ifuncs are examined and \p GetTTI yields the
// TargetTransformInfo used to query feature masks (it is invoked on the
// resolver of each ifunc). Returns true if at least one call site was
// redirected to a concrete callee version.
static bool OptimizeNonTrivialIFuncs(
    Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
  bool Changed = false;

  // Cache containing the mask constructed from a function's target features.
  DenseMap<Function *, uint64_t> FeatureMask;

  for (GlobalIFunc &IF : M.ifuncs()) {
    if (IF.isInterposable())
      continue;

    Function *Resolver = IF.getResolverFunction();
    if (!Resolver)
      continue;

    if (Resolver->isInterposable())
      continue;

    TargetTransformInfo &TTI = GetTTI(*Resolver);

    // Discover the callee versions. Every return of the resolver must lead
    // back (through selects/phis) to multiversioned functions only; otherwise
    // this ifunc is skipped.
    SmallVector<Function *> Callees;
    if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
          if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
            if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
              return true;
          return false;
        }))
      continue;

    assert(!Callees.empty() && "Expecting successful collection of versions");

    // Cache the feature mask for each callee.
    for (Function *Callee : Callees) {
      auto [It, Inserted] = FeatureMask.try_emplace(Callee);
      if (Inserted)
        It->second = TTI.getFeatureMask(*Callee);
    }

    // Sort the callee versions in decreasing priority order.
    sort(Callees, [&](auto *LHS, auto *RHS) {
      return FeatureMask[LHS] > FeatureMask[RHS];
    });

    // Find the callsites and cache the feature mask for each caller. Only
    // users that call the ifunc directly are considered; each caller is
    // recorded once together with all of its call sites.
    SmallVector<Function *> Callers;
    DenseMap<Function *, SmallVector<CallBase *>> CallSites;
    for (User *U : IF.users()) {
      if (auto *CB = dyn_cast<CallBase>(U)) {
        if (CB->getCalledOperand() == &IF) {
          Function *Caller = CB->getFunction();
          auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
          if (FeatInserted)
            FeatIt->second = TTI.getFeatureMask(*Caller);
          auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
          if (CallInserted)
            Callers.push_back(Caller);
          CallIt->second.push_back(CB);
        }
      }
    }

    // Sort the caller versions in decreasing priority order.
    sort(Callers, [&](auto *LHS, auto *RHS) {
      return FeatureMask[LHS] > FeatureMask[RHS];
    });

    // True when every bit set in B is also set in A.
    auto implies = [](uint64_t A, uint64_t B) { return (A & B) == B; };

    // Index to the highest priority candidate. It is shared by all callers
    // and only ever moves forward.
    unsigned I = 0;
    // Now try to redirect calls starting from higher priority callers.
    for (Function *Caller : Callers) {
      // The candidate index may have run past the last callee while handling
      // an earlier caller (e.g. when several callers have equal priority).
      // Nothing is left to redirect to and indexing Callees would read out of
      // bounds, so stop and leave the remaining call sites untouched.
      if (I >= Callees.size())
        break;

      Function *Callee = Callees[I];
      uint64_t CallerBits = FeatureMask[Caller];
      uint64_t CalleeBits = FeatureMask[Callee];

      // In the case of FMV callers, we know that all higher priority callers
      // than the current one did not get selected at runtime, which helps
      // reason about the callees (if they have versions that mandate presence
      // of the features which we already know are unavailable on this target).
      if (TTI.isMultiversionedFunction(*Caller)) {
        // If the feature set of the caller implies the feature set of the
        // highest priority candidate then it shall be picked. In case of
        // identical sets advance the candidate index one position.
        if (CallerBits == CalleeBits)
          ++I;
        else if (!implies(CallerBits, CalleeBits)) {
          // Keep advancing the candidate index as long as the caller's
          // features are a subset of the current candidate's.
          while (implies(CalleeBits, CallerBits)) {
            if (++I == Callees.size())
              break;
            CalleeBits = FeatureMask[Callees[I]];
          }
          continue;
        }
      } else {
        // We can't reason much about non-FMV callers. Just pick the highest
        // priority callee if it matches, otherwise bail.
        if (I > 0 || !implies(CallerBits, CalleeBits))
          continue;
      }
      // Redirect every recorded call site of this caller to the chosen callee.
      auto &Calls = CallSites[Caller];
      for (CallBase *CS : Calls)
        CS->setCalledOperand(Callee);
      Changed = true;
    }
    // Count the ifunc as resolved when it has no users left, or when its only
    // remaining users are aliases.
    if (IF.use_empty() ||
        all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
      NumIFuncsResolved++;
  }
  return Changed;
}

static bool
optimizeGlobalsInModule(Module &M, const DataLayout &DL,
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
Expand Down Expand Up @@ -2707,6 +2866,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
// Optimize IFuncs whose callees are statically known.
LocalChange |= OptimizeStaticIFuncs(M);

// Optimize IFuncs based on the target features of the caller.
LocalChange |= OptimizeNonTrivialIFuncs(M, GetTTI);

// Remove any IFuncs that are now dead.
LocalChange |= DeleteDeadIFuncs(M, NotDiscardableComdats);

Expand Down
Loading

0 comments on commit 59ea6fc

Please sign in to comment.