Skip to content

Commit fe66c7d

Browse files
committed
[FMV][GlobalOpt] Bypass the IFunc Resolver of MultiVersioned functions.
To deduce whether the optimization is legal we need to compare the target features between caller and callee versions. The criteria for bypassing the resolver are the following: * If the callee's feature set is a subset of the caller's feature set, then the callee is a candidate for direct call. * Among such candidates the one of highest priority is the best match and it shall be picked, unless there is a version of the callee with higher priority than the best match which cannot be picked from a higher priority caller (directly or through the resolver). * For every higher priority callee version than the best match, there is a higher priority caller version whose feature set availability is implied by the callee's feature set. Example: Callers and Callees are ordered in decreasing priority. The arrows indicate successful call redirections. Caller Callee Explanation ========================================================================= mops+sve2 --+--> mops all the callee versions are subsets of the | caller but mops has the highest priority | mops --+ sve2 between mops and default callees, mops wins sve sve between sve and default callees, sve wins but sve2 does not have a high priority caller default -----> default sve (callee) implies sve (caller), sve2(callee) implies sve (caller), mops(callee) implies mops(caller)
1 parent a522dbb commit fe66c7d

File tree

9 files changed

+602
-6
lines changed

9 files changed

+602
-6
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+14
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,12 @@ class TargetTransformInfo {
17621762
/// false, but it shouldn't matter what it returns anyway.
17631763
bool hasArmWideBranch(bool Thumb) const;
17641764

1765+
/// Returns true if the target supports Function MultiVersioning.
1766+
bool hasFMV() const;
1767+
1768+
/// Returns the MultiVersion priority of a given function.
1769+
uint64_t getFMVPriority(Function &F) const;
1770+
17651771
/// \return The maximum number of function arguments the target supports.
17661772
unsigned getMaxNumArgs() const;
17671773

@@ -2152,6 +2158,8 @@ class TargetTransformInfo::Concept {
21522158
virtual VPLegalization
21532159
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
21542160
virtual bool hasArmWideBranch(bool Thumb) const = 0;
2161+
virtual bool hasFMV() const = 0;
2162+
virtual uint64_t getFMVPriority(Function &F) const = 0;
21552163
virtual unsigned getMaxNumArgs() const = 0;
21562164
};
21572165

@@ -2904,6 +2912,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
29042912
return Impl.hasArmWideBranch(Thumb);
29052913
}
29062914

2915+
bool hasFMV() const override { return Impl.hasFMV(); }
2916+
2917+
uint64_t getFMVPriority(Function &F) const override {
2918+
return Impl.getFMVPriority(F);
2919+
}
2920+
29072921
unsigned getMaxNumArgs() const override {
29082922
return Impl.getMaxNumArgs();
29092923
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+4
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,10 @@ class TargetTransformInfoImplBase {
941941

942942
bool hasArmWideBranch(bool) const { return false; }
943943

944+
bool hasFMV() const { return false; }
945+
946+
uint64_t getFMVPriority(Function &F) const { return 0; }
947+
944948
unsigned getMaxNumArgs() const { return UINT_MAX; }
945949

946950
protected:

llvm/include/llvm/TargetParser/AArch64TargetParser.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,7 @@ const ArchInfo *getArchForCpu(StringRef CPU);
846846
// Parser
847847
const ArchInfo *parseArch(StringRef Arch);
848848
std::optional<ExtensionInfo> parseArchExtension(StringRef Extension);
849+
std::optional<ExtensionInfo> parseTargetFeature(StringRef Feature);
849850
// Given the name of a CPU or alias, return the correponding CpuInfo.
850851
std::optional<CpuInfo> parseCpu(StringRef Name);
851852
// Used by target parser tests
@@ -856,7 +857,8 @@ bool isX18ReservedByDefault(const Triple &TT);
856857
// For given feature names, return a bitmask corresponding to the entries of
857858
// AArch64::CPUFeatures. The values in CPUFeatures are not bitmasks
858859
// themselves, they are sequential (0, 1, 2, 3, ...).
859-
uint64_t getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs);
860+
uint64_t getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs,
861+
bool IsBackEndFeature = false);
860862

861863
void PrintSupportedExtensions(StringMap<StringRef> DescMap);
862864

llvm/lib/Analysis/TargetTransformInfo.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,12 @@ bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const {
12961296
return TTIImpl->hasArmWideBranch(Thumb);
12971297
}
12981298

1299+
bool TargetTransformInfo::hasFMV() const { return TTIImpl->hasFMV(); }
1300+
1301+
uint64_t TargetTransformInfo::getFMVPriority(Function &F) const {
1302+
return TTIImpl->getFMVPriority(F);
1303+
}
1304+
12991305
unsigned TargetTransformInfo::getMaxNumArgs() const {
13001306
return TTIImpl->getMaxNumArgs();
13011307
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/IR/IntrinsicsAArch64.h"
2222
#include "llvm/IR/PatternMatch.h"
2323
#include "llvm/Support/Debug.h"
24+
#include "llvm/TargetParser/AArch64TargetParser.h"
2425
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2526
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
2627
#include <algorithm>
@@ -231,6 +232,13 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
231232
return false;
232233
}
233234

235+
uint64_t AArch64TTIImpl::getFMVPriority(Function &F) const {
236+
StringRef FeatureStr = F.getFnAttribute("target-features").getValueAsString();
237+
SmallVector<StringRef, 8> Features;
238+
FeatureStr.split(Features, ",");
239+
return AArch64::getCpuSupportsMask(Features, /*IsBackEndFeature = */ true);
240+
}
241+
234242
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
235243
const Function *Callee) const {
236244
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
8383
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
8484
unsigned DefaultCallPenalty) const;
8585

86+
bool hasFMV() const { return ST->hasFMV(); }
87+
88+
uint64_t getFMVPriority(Function &F) const;
89+
8690
/// \name Scalar TTI Implementations
8791
/// @{
8892

llvm/lib/TargetParser/AArch64TargetParser.cpp

+13-4
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,13 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA
4747
return {};
4848
}
4949

50-
uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs) {
50+
uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs,
51+
bool IsBackEndFeature) {
5152
uint64_t FeaturesMask = 0;
52-
for (const StringRef &FeatureStr : FeatureStrs) {
53-
if (auto Ext = parseArchExtension(FeatureStr))
53+
for (const StringRef FeatureStr : FeatureStrs)
54+
if (auto Ext = IsBackEndFeature ? parseTargetFeature(FeatureStr)
55+
: parseArchExtension(FeatureStr))
5456
FeaturesMask |= (1ULL << Ext->CPUFeature);
55-
}
5657
return FeaturesMask;
5758
}
5859

@@ -132,6 +133,14 @@ std::optional<AArch64::ExtensionInfo> AArch64::parseArchExtension(StringRef Arch
132133
return {};
133134
}
134135

136+
std::optional<AArch64::ExtensionInfo>
137+
AArch64::parseTargetFeature(StringRef Feature) {
138+
for (const auto &E : Extensions)
139+
if (Feature == E.Feature)
140+
return E;
141+
return {};
142+
}
143+
135144
std::optional<AArch64::CpuInfo> AArch64::parseCpu(StringRef Name) {
136145
// Resolve aliases first.
137146
Name = resolveCPUAlias(Name);

llvm/lib/Transforms/IPO/GlobalOpt.cpp

+138-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ STATISTIC(NumAliasesRemoved, "Number of global aliases eliminated");
8989
STATISTIC(NumCXXDtorsRemoved, "Number of global C++ destructors removed");
9090
STATISTIC(NumInternalFunc, "Number of internal functions");
9191
STATISTIC(NumColdCC, "Number of functions marked coldcc");
92-
STATISTIC(NumIFuncsResolved, "Number of statically resolved IFuncs");
92+
STATISTIC(NumIFuncsResolved, "Number of resolved IFuncs");
9393
STATISTIC(NumIFuncsDeleted, "Number of IFuncs removed");
9494

9595
static cl::opt<bool>
@@ -2462,6 +2462,140 @@ DeleteDeadIFuncs(Module &M,
24622462
return Changed;
24632463
}
24642464

2465+
// Follows the use-def chain of \p V backwards until it finds a Function,
2466+
// in which case it collects in \p Versions.
2467+
static void collectVersions(Value *V, SmallVectorImpl<Function *> &Versions) {
2468+
if (auto *F = dyn_cast<Function>(V)) {
2469+
Versions.push_back(F);
2470+
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
2471+
collectVersions(Sel->getTrueValue(), Versions);
2472+
collectVersions(Sel->getFalseValue(), Versions);
2473+
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
2474+
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
2475+
collectVersions(Phi->getIncomingValue(I), Versions);
2476+
}
2477+
}
2478+
2479+
// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
2480+
// deduce whether the optimization is legal we need to compare the target
2481+
// features between caller and callee versions. The criteria for bypassing
2482+
// the resolver are the following:
2483+
//
2484+
// * If the callee's feature set is a subset of the caller's feature set,
2485+
// then the callee is a candidate for direct call.
2486+
//
2487+
// * Among such candidates the one of highest priority is the best match
2488+
// and it shall be picked, unless there is a version of the callee with
2489+
// higher priority than the best match which cannot be picked from a
2490+
// higher priority caller (directly or through the resolver).
2491+
//
2492+
// * For every higher priority callee version than the best match, there
2493+
// is a higher priority caller version whose feature set availability
2494+
// is implied by the callee's feature set.
2495+
//
2496+
static bool OptimizeNonTrivialIFuncs(
2497+
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
2498+
bool Changed = false;
2499+
2500+
// Cache containing the mask constructed from a function's target features.
2501+
DenseMap<Function *, uint64_t> FeaturePriorityMap;
2502+
2503+
for (GlobalIFunc &IF : M.ifuncs()) {
2504+
if (IF.isInterposable())
2505+
continue;
2506+
2507+
Function *Resolver = IF.getResolverFunction();
2508+
if (!Resolver)
2509+
continue;
2510+
2511+
if (Resolver->isInterposable())
2512+
continue;
2513+
2514+
TargetTransformInfo &TTI = GetTTI(*Resolver);
2515+
if (!TTI.hasFMV())
2516+
return false;
2517+
2518+
// Discover the callee versions.
2519+
SmallVector<Function *> Callees;
2520+
for (BasicBlock &BB : *Resolver)
2521+
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
2522+
collectVersions(Ret->getReturnValue(), Callees);
2523+
2524+
if (Callees.empty())
2525+
continue;
2526+
2527+
// Cache the feature mask for each callee.
2528+
for (Function *Callee : Callees) {
2529+
auto [It, Inserted] = FeaturePriorityMap.try_emplace(Callee);
2530+
if (Inserted)
2531+
It->second = TTI.getFMVPriority(*Callee);
2532+
}
2533+
2534+
// Sort the callee versions in decreasing priority order.
2535+
sort(Callees, [&](auto *LHS, auto *RHS) {
2536+
return FeaturePriorityMap[LHS] > FeaturePriorityMap[RHS];
2537+
});
2538+
2539+
// Find the callsites and cache the feature mask for each caller.
2540+
SmallVector<Function *> Callers;
2541+
DenseMap<Function *, SmallVector<CallBase *>> CallSiteMap;
2542+
for (User *U : IF.users()) {
2543+
if (auto *CB = dyn_cast<CallBase>(U)) {
2544+
if (CB->getCalledOperand() == &IF) {
2545+
Function *Caller = CB->getFunction();
2546+
auto [FeatIt, FeatInserted] = FeaturePriorityMap.try_emplace(Caller);
2547+
if (FeatInserted)
2548+
FeatIt->second = TTI.getFMVPriority(*Caller);
2549+
auto [CallIt, CallInserted] = CallSiteMap.try_emplace(Caller);
2550+
if (CallInserted)
2551+
Callers.push_back(Caller);
2552+
CallIt->second.push_back(CB);
2553+
}
2554+
}
2555+
}
2556+
2557+
// Sort the caller versions in decreasing priority order.
2558+
sort(Callers, [&](auto *LHS, auto *RHS) {
2559+
return FeaturePriorityMap[LHS] > FeaturePriorityMap[RHS];
2560+
});
2561+
2562+
// Index to the highest priority candidate.
2563+
unsigned I = 0;
2564+
// Now try to redirect calls starting from higher priority callers.
2565+
for (Function *Caller : Callers) {
2566+
// Getting here means we found callers of equal priority.
2567+
if (I == Callees.size())
2568+
break;
2569+
Function *Callee = Callees[I];
2570+
uint64_t CallerPriority = FeaturePriorityMap[Caller];
2571+
uint64_t CalleePriority = FeaturePriorityMap[Callee];
2572+
// If the priority of the caller is greater or equal to the highest
2573+
// priority candidate then it shall be picked. In case of equality
2574+
// advance the candidate index one position.
2575+
if (CallerPriority == CalleePriority)
2576+
++I;
2577+
else if (CallerPriority < CalleePriority) {
2578+
// Keep advancing the candidate index as long as the caller's
2579+
// features are a subset of the current candidate's.
2580+
while ((CallerPriority & CalleePriority) == CallerPriority) {
2581+
if (++I == Callees.size())
2582+
break;
2583+
CalleePriority = FeaturePriorityMap[Callees[I]];
2584+
}
2585+
continue;
2586+
}
2587+
auto &CallSites = CallSiteMap[Caller];
2588+
for (CallBase *CS : CallSites)
2589+
CS->setCalledOperand(Callee);
2590+
Changed = true;
2591+
}
2592+
if (IF.use_empty() ||
2593+
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
2594+
NumIFuncsResolved++;
2595+
}
2596+
return Changed;
2597+
}
2598+
24652599
static bool
24662600
optimizeGlobalsInModule(Module &M, const DataLayout &DL,
24672601
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
@@ -2525,6 +2659,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
25252659
// Optimize IFuncs whose callee's are statically known.
25262660
LocalChange |= OptimizeStaticIFuncs(M);
25272661

2662+
// Optimize IFuncs based on the target features of the caller.
2663+
LocalChange |= OptimizeNonTrivialIFuncs(M, GetTTI);
2664+
25282665
// Remove any IFuncs that are now dead.
25292666
LocalChange |= DeleteDeadIFuncs(M, NotDiscardableComdats);
25302667

0 commit comments

Comments
 (0)