diff --git a/modules/compiler/test/group_ops.cpp b/modules/compiler/test/group_ops.cpp index c6db8b8ab..0dfbac4c5 100644 --- a/modules/compiler/test/group_ops.cpp +++ b/modules/compiler/test/group_ops.cpp @@ -418,8 +418,7 @@ define void @test_wrapper(i32 %i, float %f, i32 %sg_lid, i64 %lid_x, i64 %lid_y, auto Info = BI.isMuxGroupCollective(Builtin.ID); ASSERT_TRUE(Info) << InfoStr; - // Now check that the returned values are what we expect. We don't - // check 'type' or 'function' here as it's not set by either party. + // Now check that the returned values are what we expect. assert(Info && "Asserting the optional to silence a compiler warning"); EXPECT_EQ(Info->Op, GroupOps[GroupOpIdx].Collective.Op) << InfoStr; EXPECT_EQ(Info->Scope, GroupOps[GroupOpIdx].Collective.Scope) << InfoStr; @@ -428,6 +427,11 @@ define void @test_wrapper(i32 %i, float %f, i32 %sg_lid, i64 %lid_x, i64 %lid_y, EXPECT_EQ(Info->Recurrence, GroupOps[GroupOpIdx].Collective.Recurrence) << InfoStr; + EXPECT_EQ(Builtin.ID, BI.getMuxGroupCollective(*Info)) << InfoStr; + EXPECT_EQ(Builtin.ID, + BI.getMuxGroupCollective(GroupOps[GroupOpIdx].Collective)) + << InfoStr; + ++GroupOpIdx; } } diff --git a/modules/compiler/utils/include/compiler/utils/builtin_info.h b/modules/compiler/utils/include/compiler/utils/builtin_info.h index b321996b0..07918c1f5 100644 --- a/modules/compiler/utils/include/compiler/utils/builtin_info.h +++ b/modules/compiler/utils/include/compiler/utils/builtin_info.h @@ -441,6 +441,10 @@ class BuiltinInfo { /// @brief Gets information about a mux group operation builtin static std::optional isMuxGroupCollective(BuiltinID ID); + /// @brief Returns the mux builtin ID matching the group collective, or + /// eBuiltinInvalid. + static BuiltinID getMuxGroupCollective(const GroupCollective &Group); + /// @brief Returns true if the mux builtin has a barrier ID as its first /// operand. static bool isMuxBuiltinWithBarrierID(BuiltinID ID) { diff --git a/modules/compiler/utils/source/builtin_info.cpp b/modules/compiler/utils/source/builtin_info.cpp index 22489f414..143352cf1 100644 --- a/modules/compiler/utils/source/builtin_info.cpp +++ b/modules/compiler/utils/source/builtin_info.cpp @@ -247,6 +247,7 @@ std::pair> BuiltinInfo::identifyMuxBuiltin( } return {ID, OverloadInfo}; +#undef SCOPED_GROUP_OP } BuiltinUniformity BuiltinInfo::isBuiltinUniform(Builtin const &B, @@ -875,6 +876,7 @@ std::string BuiltinInfo::getMuxBuiltinName(BuiltinID ID, return BaseName + "_" + getMangledTypeStr(Ty); } llvm_unreachable("Unhandled mux builtin"); +#undef CASE_GROUP_OP_ALL_SCOPES } Function *BuiltinInfo::defineMuxBuiltin(BuiltinID ID, Module &M, @@ -1093,6 +1095,88 @@ std::optional BuiltinInfo::isMuxGroupCollective(BuiltinID ID) { } return Collective; +#undef CASE_GROUP_OP_ALL_SCOPES +} + +BuiltinID BuiltinInfo::getMuxGroupCollective(const GroupCollective &Group) { +#define SIMPLE_SCOPE_SWITCH(OP) \ + do { \ + switch (Group.Scope) { \ + default: \ + llvm_unreachable("Impossible scope kind"); \ + case GroupCollective::ScopeKind::SubGroup: \ + return eMuxBuiltinSubgroup##OP; \ + case GroupCollective::ScopeKind::WorkGroup: \ + return eMuxBuiltinWorkgroup##OP; \ + case GroupCollective::ScopeKind::VectorGroup: \ + return eMuxBuiltinVecgroup##OP; \ + } \ + } while (0) + +#define COMPLEX_SCOPE_SWITCH(OP, SUFFIX) \ + do { \ + switch (Group.Recurrence) { \ + default: \ + llvm_unreachable("Unhandled recursion kind"); \ + case RecurKind::Add: \ + SIMPLE_SCOPE_SWITCH(OP##Add##SUFFIX); \ + case RecurKind::Mul: \ + SIMPLE_SCOPE_SWITCH(OP##Mul##SUFFIX); \ + case RecurKind::FAdd: \ + SIMPLE_SCOPE_SWITCH(OP##FAdd##SUFFIX); \ + case RecurKind::FMul: \ + SIMPLE_SCOPE_SWITCH(OP##FMul##SUFFIX); \ + case RecurKind::SMin: \ + SIMPLE_SCOPE_SWITCH(OP##SMin##SUFFIX); \ + case RecurKind::UMin: \ + SIMPLE_SCOPE_SWITCH(OP##UMin##SUFFIX); \ + case RecurKind::FMin: \ + SIMPLE_SCOPE_SWITCH(OP##FMin##SUFFIX); \ + case RecurKind::SMax: \ + SIMPLE_SCOPE_SWITCH(OP##SMax##SUFFIX); \ + case RecurKind::UMax: \ + SIMPLE_SCOPE_SWITCH(OP##UMax##SUFFIX); \ + case RecurKind::FMax: \ + SIMPLE_SCOPE_SWITCH(OP##FMax##SUFFIX); \ + case RecurKind::And: \ + if (Group.IsLogical) { \ + SIMPLE_SCOPE_SWITCH(OP##LogicalAnd##SUFFIX); \ + } else { \ + SIMPLE_SCOPE_SWITCH(OP##And##SUFFIX); \ + } \ + case RecurKind::Or: \ + if (Group.IsLogical) { \ + SIMPLE_SCOPE_SWITCH(OP##LogicalOr##SUFFIX); \ + } else { \ + SIMPLE_SCOPE_SWITCH(OP##Or##SUFFIX); \ + } \ + case RecurKind::Xor: \ + if (Group.IsLogical) { \ + SIMPLE_SCOPE_SWITCH(OP##LogicalXor##SUFFIX); \ + } else { \ + SIMPLE_SCOPE_SWITCH(OP##Xor##SUFFIX); \ + } \ + } \ + } while (0) + + switch (Group.Op) { + case GroupCollective::OpKind::All: + SIMPLE_SCOPE_SWITCH(All); + case GroupCollective::OpKind::Any: + SIMPLE_SCOPE_SWITCH(Any); + case GroupCollective::OpKind::Broadcast: + SIMPLE_SCOPE_SWITCH(Broadcast); + case GroupCollective::OpKind::Reduction: + COMPLEX_SCOPE_SWITCH(Reduce, ); + case GroupCollective::OpKind::ScanExclusive: + COMPLEX_SCOPE_SWITCH(Scan, Exclusive); + case GroupCollective::OpKind::ScanInclusive: + COMPLEX_SCOPE_SWITCH(Scan, Inclusive); + break; + } + return eBuiltinInvalid; +#undef COMPLEX_SCOPE_SWITCH +#undef SCOPE_SWITCH } bool BuiltinInfo::isOverloadableMuxBuiltinID(BuiltinID ID) {