Skip to content

Commit

Permalink
Merge pull request #86 from frasercrmck/mux-group-builtins-further
Browse files Browse the repository at this point in the history
[compiler] Further integrate mux group builtins
  • Loading branch information
frasercrmck authored Aug 14, 2023
2 parents 8ad8e61 + 447d97b commit 5b0365a
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 114 deletions.
2 changes: 1 addition & 1 deletion modules/compiler/test/group_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class GroupOpsTest : public CompilerLLVMModuleTest {
std::string getLLVMFnString(StringRef ParamName = "%x") const {
std::string FnStr =
LLVMTy + " @" + MangledFnName + "(" + LLVMTy + " " + ParamName.str();
if (Collective.Op == GroupCollective::OpKind::Broadcast) {
if (Collective.isBroadcast()) {
if (Collective.isSubGroupScope()) {
FnStr += ", i32 %sg_lid";
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ struct GroupCollective {
bool isScan() const {
return Op == OpKind::ScanExclusive || Op == OpKind::ScanInclusive;
}
/// @brief Returns true for reduction collective operations.
bool isReduction() const { return Op == OpKind::Reduction; }
/// @brief Returns true for broadcast collective operations.
bool isBroadcast() const { return Op == OpKind::Broadcast; }
/// @brief Returns true for sub-group collective operations.
bool isSubGroupScope() const { return Scope == ScopeKind::SubGroup; }
/// @brief Returns true for work-group collective operations.
Expand Down
26 changes: 7 additions & 19 deletions modules/compiler/utils/source/builtin_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,27 +275,15 @@ BuiltinUniformity BuiltinInfo::isBuiltinUniform(Builtin const &B,
// not support vectorizing along y or z (see CA-2843).
return SimdDimIdx ? eBuiltinUniformityNever
: eBuiltinUniformityInstanceID;
case eMuxBuiltinSubgroupAll:
case eMuxBuiltinSubgroupAny:
case eMuxBuiltinSubgroupBroadcast:
case eMuxBuiltinSubgroupReduceAdd:
case eMuxBuiltinSubgroupReduceFAdd:
case eMuxBuiltinSubgroupReduceMul:
case eMuxBuiltinSubgroupReduceFMul:
case eMuxBuiltinSubgroupReduceSMax:
case eMuxBuiltinSubgroupReduceUMax:
case eMuxBuiltinSubgroupReduceFMax:
case eMuxBuiltinSubgroupReduceSMin:
case eMuxBuiltinSubgroupReduceUMin:
case eMuxBuiltinSubgroupReduceFMin:
case eMuxBuiltinSubgroupReduceAnd:
case eMuxBuiltinSubgroupReduceOr:
case eMuxBuiltinSubgroupReduceXor:
case eMuxBuiltinSubgroupReduceLogicalAnd:
case eMuxBuiltinSubgroupReduceLogicalOr:
case eMuxBuiltinSubgroupReduceLogicalXor:
}

// Any/all, reductions and broadcasts are always uniform
if (auto Info = isMuxGroupCollective(B.ID)) {
if (Info->isAnyAll() || Info->isReduction() || Info->isBroadcast()) {
return eBuiltinUniformityAlways;
}
}

if (LangImpl) {
return LangImpl->isBuiltinUniform(B, CI, SimdDimIdx);
}
Expand Down
5 changes: 0 additions & 5 deletions modules/compiler/utils/source/cl_builtin_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,6 @@ Builtin CLBuiltinInfo::analyzeBuiltin(Function const &Callee) const {

bool IsConvergent = false;
unsigned Properties = eBuiltinPropertyNone;
llvm::SmallVector<llvm::Type *, 2> OverloadInfo;
switch (ID) {
default:
// Assume convergence on unknown builtins.
Expand Down Expand Up @@ -1278,10 +1277,6 @@ Builtin CLBuiltinInfo::analyzeBuiltin(Function const &Callee) const {
case eCLBuiltinWorkgroupScanLogicalXorExclusive:
IsConvergent = true;
Properties |= eBuiltinPropertyMapToMuxGroupBuiltin;
if (ID != eCLBuiltinWorkgroupAll && ID != eCLBuiltinWorkgroupAny &&
ID != eCLBuiltinSubgroupAll && ID != eCLBuiltinSubgroupAny) {
OverloadInfo.push_back(Callee.getArg(0)->getType());
}
break;
}

Expand Down
204 changes: 145 additions & 59 deletions modules/compiler/utils/source/degenerate_sub_group_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <compiler/utils/degenerate_sub_group_pass.h>
#include <compiler/utils/device_info.h>
#include <compiler/utils/group_collective_helpers.h>
#include <compiler/utils/mangling.h>
#include <compiler/utils/metadata.h>
#include <compiler/utils/pass_functions.h>
#include <llvm/ADT/SmallVector.h>
Expand All @@ -37,7 +36,6 @@
#include <llvm/Support/ErrorHandling.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <llvm/Transforms/Utils/ValueMapper.h>
#include <multi_llvm/optional_helper.h>

#include <set>

Expand All @@ -50,54 +48,144 @@ namespace {
/// @param[in] CI Call instruction to query.
///
/// @return The analyzed builtin if CI is a call to a sub-group builtin,
/// std::nullopt otherwise.
bool isSubGroupFunction(CallInst *CI, compiler::utils::BuiltinInfo &BI) {
std::optional<compiler::utils::Builtin> isSubGroupFunction(
CallInst *CI, compiler::utils::BuiltinInfo &BI) {
auto *Fcn = CI->getCalledFunction();
assert(Fcn && "virtual calls are not supported");
auto Builtin = BI.analyzeBuiltin(*Fcn);
if (auto GroupOp = BI.isMuxGroupCollective(Builtin.ID)) {
return GroupOp->isSubGroupScope();
auto SGBuiltin = BI.analyzeBuiltin(*Fcn);

if (SGBuiltin.ID == compiler::utils::eMuxBuiltinSubGroupBarrier) {
return SGBuiltin;
}
if (auto GroupOp = BI.isMuxGroupCollective(SGBuiltin.ID);
GroupOp && GroupOp->isSubGroupScope()) {
return SGBuiltin;
}

return Builtin.ID == compiler::utils::eMuxBuiltinSubGroupBarrier;
return std::nullopt;
}

/// @brief Helper for building the symbol name of the mangled work-group builtin
/// corresponding to the sub-group builtin.
///
/// @return The mangled work-group builtin corresponding to `SubgroupBuiltin`.
Function *lookupWGBuiltin(Function &SubgroupF, Module &M) {
// We must handle the case where we're replacing a __mux_sub_group_barrier
// with a __mux_work_group_barrier. Our 'demangleName' API works differently
// with non-mangled builtin names and returns an empty string. Just work
// around it specifically.
auto SubgroupFName = SubgroupF.getName();
auto SubgroupFTy = SubgroupF.getFunctionType();
if (SubgroupFName == compiler::utils::MuxBuiltins::sub_group_barrier) {
auto *WorkgroupF =
M.getOrInsertFunction(compiler::utils::MuxBuiltins::work_group_barrier,
SubgroupFTy)
.getCallee();
return cast<Function>(WorkgroupF);
}

std::string WorkgroupFName = "__mux_work" + SubgroupFName.substr(9).str();

// Work-group builtins have an extra 'barrier id' as the first parameter.
SmallVector<Type *, 4> ArgTys;
ArgTys.push_back(IntegerType::getInt32Ty(M.getContext()));
ArgTys.push_back(SubgroupF.getArg(0)->getType());
// Work-group broadcasts unconditionally have three work-item ID parameters
// (x, y, z), whereas sub-group broadcasts just have one (sub-group ID).
if (SubgroupFName.contains("broadcast")) {
for (unsigned i = 0; i < 3; i++) {
ArgTys.push_back(compiler::utils::getSizeType(M));
/// @return The work-group equivalent of the given builtin.
Function *lookupWGBuiltin(const compiler::utils::Builtin &SGBuiltin,
compiler::utils::BuiltinInfo &BI, Module &M) {
compiler::utils::BuiltinID WGBuiltinID = [](compiler::utils::BuiltinID ID) {
switch (ID) {
default:
return compiler::utils::eBuiltinInvalid;
case compiler::utils::eMuxBuiltinSubGroupBarrier:
return compiler::utils::eMuxBuiltinWorkGroupBarrier;
case compiler::utils::eMuxBuiltinSubgroupAny:
return compiler::utils::eMuxBuiltinWorkgroupAny;
case compiler::utils::eMuxBuiltinSubgroupAll:
return compiler::utils::eMuxBuiltinWorkgroupAll;
case compiler::utils::eMuxBuiltinSubgroupBroadcast:
return compiler::utils::eMuxBuiltinWorkgroupBroadcast;
case compiler::utils::eMuxBuiltinSubgroupReduceAdd:
return compiler::utils::eMuxBuiltinWorkgroupReduceAdd;
case compiler::utils::eMuxBuiltinSubgroupReduceFAdd:
return compiler::utils::eMuxBuiltinWorkgroupReduceFAdd;
case compiler::utils::eMuxBuiltinSubgroupReduceMul:
return compiler::utils::eMuxBuiltinWorkgroupReduceMul;
case compiler::utils::eMuxBuiltinSubgroupReduceFMul:
return compiler::utils::eMuxBuiltinWorkgroupReduceFMul;
case compiler::utils::eMuxBuiltinSubgroupReduceUMax:
return compiler::utils::eMuxBuiltinWorkgroupReduceUMax;
case compiler::utils::eMuxBuiltinSubgroupReduceSMax:
return compiler::utils::eMuxBuiltinWorkgroupReduceSMax;
case compiler::utils::eMuxBuiltinSubgroupReduceFMax:
return compiler::utils::eMuxBuiltinWorkgroupReduceFMax;
case compiler::utils::eMuxBuiltinSubgroupReduceUMin:
return compiler::utils::eMuxBuiltinWorkgroupReduceUMin;
case compiler::utils::eMuxBuiltinSubgroupReduceSMin:
return compiler::utils::eMuxBuiltinWorkgroupReduceSMin;
case compiler::utils::eMuxBuiltinSubgroupReduceFMin:
return compiler::utils::eMuxBuiltinWorkgroupReduceFMin;
case compiler::utils::eMuxBuiltinSubgroupReduceAnd:
return compiler::utils::eMuxBuiltinWorkgroupReduceAnd;
case compiler::utils::eMuxBuiltinSubgroupReduceOr:
return compiler::utils::eMuxBuiltinWorkgroupReduceOr;
case compiler::utils::eMuxBuiltinSubgroupReduceXor:
return compiler::utils::eMuxBuiltinWorkgroupReduceXor;
case compiler::utils::eMuxBuiltinSubgroupReduceLogicalAnd:
return compiler::utils::eMuxBuiltinWorkgroupReduceLogicalAnd;
case compiler::utils::eMuxBuiltinSubgroupReduceLogicalOr:
return compiler::utils::eMuxBuiltinWorkgroupReduceLogicalOr;
case compiler::utils::eMuxBuiltinSubgroupReduceLogicalXor:
return compiler::utils::eMuxBuiltinWorkgroupReduceLogicalXor;
case compiler::utils::eMuxBuiltinSubgroupScanAddInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanAddInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanFAddInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanFAddInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanMulInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanMulInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanFMulInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanFMulInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanUMaxInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanUMaxInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanSMaxInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanSMaxInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanFMaxInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanFMaxInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanUMinInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanUMinInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanSMinInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanSMinInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanFMinInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanFMinInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanAndInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanAndInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanOrInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanOrInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanXorInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanXorInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanLogicalAndInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanLogicalAndInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanLogicalOrInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanLogicalOrInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanLogicalXorInclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanLogicalXorInclusive;
case compiler::utils::eMuxBuiltinSubgroupScanAddExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanAddExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanFAddExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanFAddExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanMulExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanMulExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanFMulExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanFMulExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanUMaxExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanUMaxExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanSMaxExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanSMaxExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanFMaxExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanFMaxExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanUMinExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanUMinExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanSMinExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanSMinExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanFMinExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanFMinExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanAndExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanAndExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanOrExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanOrExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanXorExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanXorExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanLogicalAndExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanLogicalAndExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanLogicalOrExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanLogicalOrExclusive;
case compiler::utils::eMuxBuiltinSubgroupScanLogicalXorExclusive:
return compiler::utils::eMuxBuiltinWorkgroupScanLogicalXorExclusive;
}
}
auto *WorkgroupFTy = FunctionType::get(SubgroupFTy->getReturnType(), ArgTys,
SubgroupFTy->isVarArg());
auto *WorkgroupF =
M.getOrInsertFunction(WorkgroupFName, WorkgroupFTy).getCallee();
return cast<Function>(WorkgroupF);
}(SGBuiltin.ID);

assert(WGBuiltinID != compiler::utils::eBuiltinInvalid &&
"Missing sub-group -> work-group mapping");
auto *WGBuiltin =
BI.getOrDeclareMuxBuiltin(WGBuiltinID, M, SGBuiltin.mux_overload_info);
assert(WGBuiltin && "Missing work-group builtin");

return WGBuiltin;
}

/// @brief Helper for determining if a call instruction calls a sub-group
Expand All @@ -121,27 +209,25 @@ bool isSubGroupWorkItemFunction(CallInst *CI) {
///
/// @param[in] SubGroupBuiltinCalls Builtin calls to replace.
void replaceSubGroupBuiltinCalls(
const SmallVectorImpl<CallInst *> &SubGroupBuiltinCalls,
const SmallVectorImpl<std::pair<CallInst *, compiler::utils::Builtin>>
&SubGroupBuiltinCalls,
compiler::utils::BuiltinInfo &BI) {
for (auto *I : SubGroupBuiltinCalls) {
auto *const SubGroupBuiltin = I->getCalledFunction();
for (auto &[I, SGBuiltin] : SubGroupBuiltinCalls) {
auto *const M = I->getModule();
if (!SubGroupBuiltin->getName().contains("broadcast")) {
if (SGBuiltin.ID != compiler::utils::eMuxBuiltinSubgroupBroadcast) {
// We can just forward the argument directly to the
// work-group builtin for everything except broadcasts.
SmallVector<Value *, 4> Args;
if (SubGroupBuiltin->getName() !=
compiler::utils::MuxBuiltins::sub_group_barrier) {
if (SGBuiltin.ID != compiler::utils::eMuxBuiltinSubGroupBarrier) {
// Barrier ID
Args.push_back(
ConstantInt::get(IntegerType::get(M->getContext(), 32), 0));
}
for (auto &arg : I->args()) {
Args.push_back(arg);
}
auto *const WorkGroupBuiltinFcn = lookupWGBuiltin(*SubGroupBuiltin, *M);
WorkGroupBuiltinFcn->setCallingConv(SubGroupBuiltin->getCallingConv());
WorkGroupBuiltinFcn->setConvergent();
auto *const WorkGroupBuiltinFcn = lookupWGBuiltin(SGBuiltin, BI, *M);
WorkGroupBuiltinFcn->setCallingConv(I->getCallingConv());
auto *WGCI = CallInst::Create(WorkGroupBuiltinFcn, Args, "", I);
WGCI->setCallingConv(I->getCallingConv());
I->replaceAllUsesWith(WGCI);
Expand Down Expand Up @@ -187,9 +273,8 @@ void replaceSubGroupBuiltinCalls(
Builder.CreateMul(LocalSizeX, LocalSizeY), "z");

auto *const SizeType = compiler::utils::getSizeType(*M);
auto *const WorkGroupBroadcastFcn = lookupWGBuiltin(*SubGroupBuiltin, *M);
WorkGroupBroadcastFcn->setCallingConv(SubGroupBuiltin->getCallingConv());
WorkGroupBroadcastFcn->setNotConvergent();
auto *const WorkGroupBroadcastFcn = lookupWGBuiltin(SGBuiltin, BI, *M);
WorkGroupBroadcastFcn->setCallingConv(I->getCallingConv());
// Because sub_group_broadcast takes uint as its index argument but
// work_group_broadcast takes size_t, we potentially need to cast here to
// the native size_t.
Expand Down Expand Up @@ -442,16 +527,17 @@ PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run(
// The cloned functions are used by the non-degenerate subgroup kernels, so
// that we can collect subgroup builtin calls first and replace them in their
// original homes.
SmallVector<CallInst *, 32> SubGroupFunctionCalls;
SmallVector<std::pair<CallInst *, compiler::utils::Builtin>, 32>
SubGroupFunctionCalls;
SmallVector<CallInst *, 32> SubGroupWorkItemFunctionCalls;
worklist.assign(degenerateKernels.begin(), degenerateKernels.end());
worklist.append(usedByDegenerate.begin(), usedByDegenerate.end());
for (auto *const F : worklist) {
for (auto &BB : *F) {
for (auto &I : BB) {
if (auto *CI = dyn_cast<CallInst>(&I)) {
if (isSubGroupFunction(CI, BI)) {
SubGroupFunctionCalls.push_back(CI);
if (auto SGBuiltin = isSubGroupFunction(CI, BI)) {
SubGroupFunctionCalls.push_back({CI, *SGBuiltin});
} else if (isSubGroupWorkItemFunction(CI)) {
SubGroupWorkItemFunctionCalls.push_back(CI);
}
Expand Down Expand Up @@ -545,7 +631,7 @@ PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run(
replaceSubGroupWorkItemBuiltinCalls(SubGroupWorkItemFunctionCalls, BI);

// Remove the old instructions from the module.
for (auto *I : SubGroupFunctionCalls) {
for (auto &[I, _] : SubGroupFunctionCalls) {
I->eraseFromParent();
}

Expand Down
17 changes: 3 additions & 14 deletions modules/compiler/vecz/source/analysis/uniform_value_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,20 +122,9 @@ static bool isSubgroupBroadcastOrReduction(
if (!Callee) {
return false;
}
auto const Builtin = BI.analyzeBuiltin(*Callee);
if (auto Info = BI.isMuxGroupCollective(Builtin.ID);
Info && Info->isSubGroupScope()) {
switch (Info->Op) {
default:
return false;
case compiler::utils::GroupCollective::OpKind::Any:
case compiler::utils::GroupCollective::OpKind::All:
case compiler::utils::GroupCollective::OpKind::Reduction:
case compiler::utils::GroupCollective::OpKind::Broadcast:
return true;
}
}
return false;
auto Info = BI.isMuxGroupCollective(BI.analyzeBuiltin(*Callee).ID);
return Info && Info->isSubGroupScope() &&
(Info->isAnyAll() || Info->isReduction() || Info->isBroadcast());
}

void UniformValueResult::findVectorLeaves(
Expand Down
Loading

0 comments on commit 5b0365a

Please sign in to comment.