Skip to content

Commit

Permalink
Merge pull request #112 from frasercrmck/builtin-info-group-collectives
Browse files Browse the repository at this point in the history
[compiler] Provide a mapping from GroupCollective -> BuiltinID
  • Loading branch information
frasercrmck authored Aug 30, 2023
2 parents 678ee7d + a8fd772 commit 1a1d826
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 2 deletions.
8 changes: 6 additions & 2 deletions modules/compiler/test/group_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,7 @@ define void @test_wrapper(i32 %i, float %f, i32 %sg_lid, i64 %lid_x, i64 %lid_y,
auto Info = BI.isMuxGroupCollective(Builtin.ID);
ASSERT_TRUE(Info) << InfoStr;

// Now check that the returned values are what we expect. We don't
// check 'type' or 'function' here as it's not set by either party.
// Now check that the returned values are what we expect.
assert(Info && "Asserting the optional to silence a compiler warning");
EXPECT_EQ(Info->Op, GroupOps[GroupOpIdx].Collective.Op) << InfoStr;
EXPECT_EQ(Info->Scope, GroupOps[GroupOpIdx].Collective.Scope) << InfoStr;
Expand All @@ -428,6 +427,11 @@ define void @test_wrapper(i32 %i, float %f, i32 %sg_lid, i64 %lid_x, i64 %lid_y,
EXPECT_EQ(Info->Recurrence, GroupOps[GroupOpIdx].Collective.Recurrence)
<< InfoStr;

EXPECT_EQ(Builtin.ID, BI.getMuxGroupCollective(*Info)) << InfoStr;
EXPECT_EQ(Builtin.ID,
BI.getMuxGroupCollective(GroupOps[GroupOpIdx].Collective))
<< InfoStr;

++GroupOpIdx;
}
}
Expand Down
4 changes: 4 additions & 0 deletions modules/compiler/utils/include/compiler/utils/builtin_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,10 @@ class BuiltinInfo {
/// @brief Gets information about a mux group operation builtin
static std::optional<GroupCollective> isMuxGroupCollective(BuiltinID ID);

/// @brief Returns the mux builtin ID matching the group collective, or
/// eBuiltinInvalid.
static BuiltinID getMuxGroupCollective(const GroupCollective &Group);

/// @brief Returns true if the mux builtin has a barrier ID as its first
/// operand.
static bool isMuxBuiltinWithBarrierID(BuiltinID ID) {
Expand Down
84 changes: 84 additions & 0 deletions modules/compiler/utils/source/builtin_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ std::pair<BuiltinID, std::vector<Type *>> BuiltinInfo::identifyMuxBuiltin(
}

return {ID, OverloadInfo};
#undef SCOPED_GROUP_OP
}

BuiltinUniformity BuiltinInfo::isBuiltinUniform(Builtin const &B,
Expand Down Expand Up @@ -875,6 +876,7 @@ std::string BuiltinInfo::getMuxBuiltinName(BuiltinID ID,
return BaseName + "_" + getMangledTypeStr(Ty);
}
llvm_unreachable("Unhandled mux builtin");
#undef CASE_GROUP_OP_ALL_SCOPES
}

Function *BuiltinInfo::defineMuxBuiltin(BuiltinID ID, Module &M,
Expand Down Expand Up @@ -1093,6 +1095,88 @@ std::optional<GroupCollective> BuiltinInfo::isMuxGroupCollective(BuiltinID ID) {
}

return Collective;
#undef CASE_GROUP_OP_ALL_SCOPES
}

BuiltinID BuiltinInfo::getMuxGroupCollective(const GroupCollective &Group) {
#define SIMPLE_SCOPE_SWITCH(OP) \
do { \
switch (Group.Scope) { \
default: \
llvm_unreachable("Impossible scope kind"); \
case GroupCollective::ScopeKind::SubGroup: \
return eMuxBuiltinSubgroup##OP; \
case GroupCollective::ScopeKind::WorkGroup: \
return eMuxBuiltinWorkgroup##OP; \
case GroupCollective::ScopeKind::VectorGroup: \
return eMuxBuiltinVecgroup##OP; \
} \
} while (0)

#define COMPLEX_SCOPE_SWITCH(OP, SUFFIX) \
do { \
switch (Group.Recurrence) { \
default: \
llvm_unreachable("Unhandled recursion kind"); \
case RecurKind::Add: \
SIMPLE_SCOPE_SWITCH(OP##Add##SUFFIX); \
case RecurKind::Mul: \
SIMPLE_SCOPE_SWITCH(OP##Mul##SUFFIX); \
case RecurKind::FAdd: \
SIMPLE_SCOPE_SWITCH(OP##FAdd##SUFFIX); \
case RecurKind::FMul: \
SIMPLE_SCOPE_SWITCH(OP##FMul##SUFFIX); \
case RecurKind::SMin: \
SIMPLE_SCOPE_SWITCH(OP##SMin##SUFFIX); \
case RecurKind::UMin: \
SIMPLE_SCOPE_SWITCH(OP##UMin##SUFFIX); \
case RecurKind::FMin: \
SIMPLE_SCOPE_SWITCH(OP##FMin##SUFFIX); \
case RecurKind::SMax: \
SIMPLE_SCOPE_SWITCH(OP##SMax##SUFFIX); \
case RecurKind::UMax: \
SIMPLE_SCOPE_SWITCH(OP##UMax##SUFFIX); \
case RecurKind::FMax: \
SIMPLE_SCOPE_SWITCH(OP##FMax##SUFFIX); \
case RecurKind::And: \
if (Group.IsLogical) { \
SIMPLE_SCOPE_SWITCH(OP##LogicalAnd##SUFFIX); \
} else { \
SIMPLE_SCOPE_SWITCH(OP##And##SUFFIX); \
} \
case RecurKind::Or: \
if (Group.IsLogical) { \
SIMPLE_SCOPE_SWITCH(OP##LogicalOr##SUFFIX); \
} else { \
SIMPLE_SCOPE_SWITCH(OP##Or##SUFFIX); \
} \
case RecurKind::Xor: \
if (Group.IsLogical) { \
SIMPLE_SCOPE_SWITCH(OP##LogicalXor##SUFFIX); \
} else { \
SIMPLE_SCOPE_SWITCH(OP##Xor##SUFFIX); \
} \
} \
} while (0)

switch (Group.Op) {
case GroupCollective::OpKind::All:
SIMPLE_SCOPE_SWITCH(All);
case GroupCollective::OpKind::Any:
SIMPLE_SCOPE_SWITCH(Any);
case GroupCollective::OpKind::Broadcast:
SIMPLE_SCOPE_SWITCH(Broadcast);
case GroupCollective::OpKind::Reduction:
COMPLEX_SCOPE_SWITCH(Reduce, );
case GroupCollective::OpKind::ScanExclusive:
COMPLEX_SCOPE_SWITCH(Scan, Exclusive);
case GroupCollective::OpKind::ScanInclusive:
COMPLEX_SCOPE_SWITCH(Scan, Inclusive);
break;
}
return eBuiltinInvalid;
#undef COMPLEX_SCOPE_SWITCH
#undef SCOPE_SWITCH
}

bool BuiltinInfo::isOverloadableMuxBuiltinID(BuiltinID ID) {
Expand Down

0 comments on commit 1a1d826

Please sign in to comment.