@@ -68,18 +68,23 @@ std::optional<compiler::utils::Builtin> isSubGroupFunction(
68
68
}
69
69
70
70
// / @return The work-group equivalent of the given builtin.
71
- Function * lookupWGBuiltin ( const compiler::utils::Builtin &SGBuiltin,
72
- compiler::utils::BuiltinInfo &BI, Module &M) {
73
- compiler::utils::BuiltinID WGBuiltinID = compiler::utils::eBuiltinInvalid;
71
+ compiler::utils::BuiltinID lookupWGBuiltinID (
72
+ const compiler::utils::Builtin &SGBuiltin,
73
+ compiler::utils::BuiltinInfo &BI) {
74
74
if (SGBuiltin.ID == compiler::utils::eMuxBuiltinSubGroupBarrier) {
75
- WGBuiltinID = compiler::utils::eMuxBuiltinWorkGroupBarrier;
76
- } else {
77
- auto SGCollective = BI.isMuxGroupCollective (SGBuiltin.ID );
78
- assert (SGCollective.has_value () && " Not a sub-group builtin" );
79
- auto WGCollective = *SGCollective;
80
- WGCollective.Scope = compiler::utils::GroupCollective::ScopeKind::WorkGroup;
81
- WGBuiltinID = BI.getMuxGroupCollective (WGCollective);
75
+ return compiler::utils::eMuxBuiltinWorkGroupBarrier;
82
76
}
77
+ auto SGCollective = BI.isMuxGroupCollective (SGBuiltin.ID );
78
+ assert (SGCollective.has_value () && " Not a sub-group builtin" );
79
+ auto WGCollective = *SGCollective;
80
+ WGCollective.Scope = compiler::utils::GroupCollective::ScopeKind::WorkGroup;
81
+ return BI.getMuxGroupCollective (WGCollective);
82
+ }
83
+
84
+ /// @return The work-group equivalent of the given builtin, or nullptr if it has none.
85
+ Function *lookupWGBuiltin (const compiler::utils::Builtin &SGBuiltin,
86
+ compiler::utils::BuiltinInfo &BI, Module &M) {
87
+ compiler::utils::BuiltinID WGBuiltinID = lookupWGBuiltinID (SGBuiltin, BI);
83
88
// Not all sub-group builtins have a work-group equivalent.
84
89
if (WGBuiltinID == compiler::utils::eBuiltinInvalid) {
85
90
return nullptr ;
@@ -268,6 +273,13 @@ PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run(
268
273
if (isKernelEntryPt (F)) {
269
274
kernels.push_back (&F);
270
275
276
+ if (compiler::utils::getReqdSubgroupSize (F)) {
277
+ // If there's a user-specified required sub-group size, we don't need to
278
+ // clone this kernel. If vectorization fails to produce the right
279
+ // sub-group size, we'll fail compilation.
280
+ continue ;
281
+ }
282
+
271
283
auto const local_sizes = compiler::utils::getLocalSizeMetadata (F);
272
284
if (!local_sizes) {
273
285
// If we don't know the local size at compile time, we can't guarantee
@@ -337,7 +349,8 @@ PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run(
337
349
if (usesSubgroups.insert (&F).second ) {
338
350
worklist.push_back (&F);
339
351
}
340
- if (SGBuiltin && !lookupWGBuiltin (*SGBuiltin, BI, M)) {
352
+ if (SGBuiltin && lookupWGBuiltinID (*SGBuiltin, BI) ==
353
+ compiler::utils::eBuiltinInvalid) {
341
354
poisonList.insert (&F);
342
355
}
343
356
}
0 commit comments