Skip to content

Commit 3a05571

Browse files
authored
Merge pull request #125 from frasercrmck/vecz-reqd-sub-group-size
[compiler] Vectorize to any required sub-group size
2 parents 142c2d4 + 18dc877 commit 3a05571

16 files changed

+243
-30
lines changed

modules/compiler/cookie/{{cookiecutter.target_name}}/source/{{cookiecutter.target_name}}_pass_machinery.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,11 @@ bool {{cookiecutter.target_name.capitalize()}}VeczPassOpts(
249249
vecz_mode == compiler::VectorizationMode::NEVER) {
250250
return false;
251251
}
252+
// Handle required sub-group sizes
253+
if (auto reqd_subgroup_vf = vecz::getReqdSubgroupSizeOpts(F)) {
254+
PassOpts.assign(1, *reqd_subgroup_vf);
255+
return true;
256+
}
252257
auto env_var_opts = processOptimizationOptions(/*env_debug_prefix*/ {});
253258
if (!env_var_opts.vecz_pass_opts.has_value()) {
254259
return false;

modules/compiler/riscv/source/riscv_pass_machinery.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ bool riscvVeczPassOpts(llvm::Function &F, llvm::ModuleAnalysisManager &,
148148
vecz_mode == compiler::VectorizationMode::NEVER) {
149149
return false;
150150
}
151+
// Handle required sub-group sizes
152+
if (auto reqd_subgroup_vf = vecz::getReqdSubgroupSizeOpts(F)) {
153+
PassOpts.assign(1, *reqd_subgroup_vf);
154+
return true;
155+
}
151156
auto env_var_opts = RiscvPassMachinery::processOptimizationOptions(
152157
/*env_debug_prefix*/ {}, vecz_mode);
153158
if (env_var_opts.vecz_pass_opts.empty()) {

modules/compiler/targets/host/source/HostPassMachinery.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ bool hostVeczPassOpts(llvm::Function &F, llvm::ModuleAnalysisManager &MAM,
6767
if (!compiler::utils::isKernelEntryPt(F)) {
6868
return false;
6969
}
70+
// Handle required sub-group sizes
71+
if (auto reqd_subgroup_vf = vecz::getReqdSubgroupSizeOpts(F)) {
72+
Opts.assign(1, *reqd_subgroup_vf);
73+
return true;
74+
}
7075
const auto &DI =
7176
MAM.getResult<compiler::utils::DeviceInfoAnalysis>(*F.getParent());
7277
auto max_work_width = DI.max_work_width;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; Copyright (C) Codeplay Software Limited
2+
;
3+
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
4+
; Exceptions; you may not use this file except in compliance with the License.
5+
; You may obtain a copy of the License at
6+
;
7+
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
8+
;
9+
; Unless required by applicable law or agreed to in writing, software
10+
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
; License for the specific language governing permissions and limitations
13+
; under the License.
14+
;
15+
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
16+
17+
; Let vecz pick the right vectorization factor for this kernel and check that the
18+
; verification pass correctly notes we've satisfied the required sub-group
19+
; size.
20+
; RUN: env muxc --device "%riscv_device" \
21+
; RUN: --passes run-vecz,verify-reqd-sub-group-satisfied < %s \
22+
; RUN: | FileCheck %s
23+
24+
; CHECK-LABEL: define void @__vecz_v8_bar_sg8(ptr addrspace(1) %in, ptr addrspace(1) %out) #0 !intel_reqd_sub_group_size !0 !codeplay_ca_vecz.derived !{{[0-9]+}} {
25+
26+
define void @bar_sg8(ptr addrspace(1) %in, ptr addrspace(1) %out) #0 !intel_reqd_sub_group_size !0 {
27+
%id = call i64 @__mux_get_global_id(i32 0)
28+
%in.addr = getelementptr i32, ptr addrspace(1) %in, i64 %id
29+
%x = load i32, ptr addrspace(1) %in.addr
30+
%y = add i32 %x, 1
31+
%out.addr = getelementptr i32, ptr addrspace(1) %out, i64 %id
32+
store i32 %y, ptr addrspace(1) %out.addr
33+
ret void
34+
}
35+
36+
declare i64 @__mux_get_global_id(i32)
37+
38+
attributes #0 = { "mux-kernel"="entry-point" }
39+
40+
!0 = !{i32 8}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; Copyright (C) Codeplay Software Limited
2+
;
3+
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
4+
; Exceptions; you may not use this file except in compliance with the License.
5+
; You may obtain a copy of the License at
6+
;
7+
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
8+
;
9+
; Unless required by applicable law or agreed to in writing, software
10+
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
; License for the specific language governing permissions and limitations
13+
; under the License.
14+
;
15+
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
16+
17+
; Try and forcibly vectorize this no-vecz kernel by 8 and check that the
18+
; vectorizer does not run, since the required sub-group size is 1. Then check
19+
; that the verification pass correctly picks up that we have satisfied the
20+
; kernel's required sub-group size by way of not vectorizing.
21+
; RUN: env CA_RISCV_VF=8 muxc --device "%riscv_device" \
22+
; RUN: --passes run-vecz,verify-reqd-sub-group-satisfied < %s \
23+
; RUN: | FileCheck %s
24+
25+
; CHECK-NOT: __vecz_
26+
define void @foo_sg1() #0 !intel_reqd_sub_group_size !2 {
27+
ret void
28+
}
29+
30+
attributes #0 = { "mux-kernel"="entry-point" }
31+
32+
!2 = !{i32 1}

modules/compiler/targets/riscv/test/lit/passes/verify-reqd-sg-size.ll

+8-11
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,27 @@
1414
;
1515
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
1616

17-
; Forcibly vectorize this kernel by 8 and check that the verification pass
18-
; correctly picks up that we haven't satisfied the kernel's required sub-group
19-
; size.
20-
; FIXME: This is conflating vecz dimension and sub-group size but that's all we
21-
; can manage at the moment.
17+
; Try and forcibly vectorize this no-vecz kernel by 8 and check that the
18+
; verification pass correctly picks up that we haven't satisfied the kernel's
19+
; required sub-group size.
2220
; RUN: env CA_RISCV_VF=8 not muxc --device "%riscv_device" \
2321
; RUN: --passes "run-vecz,verify-reqd-sub-group-satisfied" %s 2>&1 \
2422
; RUN: | FileCheck %s
2523

26-
; CHECK: kernel.cl:10:0: kernel has required sub-group size 7 but the compiler was unable to sastify this constraint
27-
define void @foo_sg7() #0 !dbg !5 !intel_reqd_sub_group_size !2 {
24+
; CHECK: kernel.cl:10:0: kernel has required sub-group size 8 but the compiler was unable to sastify this constraint
25+
define void @foo_sg8() #0 !dbg !5 !intel_reqd_sub_group_size !2 {
2826
ret void
2927
}
3028

29+
attributes #0 = { "mux-kernel"="entry-point" "vecz-mode"="never" }
30+
3131
!llvm.dbg.cu = !{!0}
3232
!llvm.module.flags = !{!1}
3333

3434
!0 = distinct !DICompileUnit(language: DW_LANG_C99, file: !4, runtimeVersion: 0, emissionKind: FullDebug)
3535
!1 = !{i32 2, !"Debug Info Version", i32 3}
3636

37-
!2 = !{i32 7}
38-
!3 = !{i32 6}
37+
!2 = !{i32 8}
3938

4039
!4 = !DIFile(filename: "kernel.cl", directory: "/oneAPI")
4140
!5 = distinct !DISubprogram(name: "foo_sg7", scope: !4, file: !4, line: 10, scopeLine: 10, flags: DIFlagArtificial | DIFlagPrototyped, unit: !0)
42-
43-
attributes #0 = { "mux-kernel"="entry-point" }

modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ entry:
3434
ret i32 %call
3535
}
3636

37-
; CHECK: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32)
3837

3938
; CHECK-LABEL: define spir_func i32 @sub_group_reduce_add_test.degenerate-subgroups
4039
; CHECK: (i32 [[Y:%.*]]) #[[ATTR0:[0-9]+]]
@@ -43,6 +42,8 @@ entry:
4342
; CHECK: ret i32 [[RESULT]]
4443
; CHECK: }
4544

45+
; CHECK: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32)
46+
4647
declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32)
4748

4849
attributes #0 = { "mux-kernel"="entry-point" }

modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll

+2-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ entry:
7878
}
7979

8080
declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32)
81-
; CHECK: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32)
8281

8382
; CHECK: define spir_func i32 @sub_groups.degenerate-subgroups(i32 [[X3:%.+]]) #[[ATTR2:[0-9]+]] {
8483
; CHECK: entry:
@@ -94,6 +93,8 @@ declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32)
9493
; CHECK: ret i32 [[R1]]
9594
; CHECK: }
9695

96+
; CHECK: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32)
97+
9798
!opencl.ocl.version = !{!0}
9899

99100
!0 = !{i32 3, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
; Copyright (C) Codeplay Software Limited
2+
;
3+
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
4+
; Exceptions; you may not use this file except in compliance with the License.
5+
; You may obtain a copy of the License at
6+
;
7+
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
8+
;
9+
; Unless required by applicable law or agreed to in writing, software
10+
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
; License for the specific language governing permissions and limitations
13+
; under the License.
14+
;
15+
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
16+
; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s
17+
18+
; Check that the DegenerateSubGroupPass does not clone any kernels with
19+
; required sub-group sizes.
20+
21+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
22+
target triple = "spir64-unknown-unknown"
23+
24+
; CHECK-NOT: {{(work_group|foo)}}
25+
26+
define spir_func i32 @clone_this(i32 %x) {
27+
entry:
28+
%call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x)
29+
ret i32 %call
30+
}
31+
32+
define spir_func i32 @shared(i32 %x) {
33+
entry:
34+
%sqr = mul i32 %x, %x
35+
ret i32 %sqr
36+
}
37+
38+
define spir_func i32 @sub_groups(i32 %x) #0 !intel_reqd_sub_group_size !1 {
39+
entry:
40+
%call1 = call spir_func i32 @clone_this(i32 %x)
41+
%call2 = call spir_func i32 @shared(i32 %x)
42+
%add = add i32 %call1, %call2
43+
ret i32 %add
44+
}
45+
46+
define spir_func i32 @no_sub_groups(i32 %x) #0 !intel_reqd_sub_group_size !1 {
47+
entry:
48+
%call = call spir_func i32 @shared(i32 %x)
49+
ret i32 %call
50+
}
51+
52+
declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32)
53+
54+
!opencl.ocl.version = !{!0}
55+
56+
!0 = !{i32 3, i32 0}
57+
!1 = !{i32 4}
58+
59+
attributes #0 = { "mux-kernel"="entry-point" }

modules/compiler/utils/include/compiler/utils/attributes.h

+5
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ void setHasDegenerateSubgroups(llvm::Function &F);
172172
/// @param[in] F Function to check.
173173
bool hasDegenerateSubgroups(const llvm::Function &F);
174174

175+
/// @brief Returns the mux sub-group size for the current function.
176+
///
177+
/// Currently always returns 1!
178+
unsigned getMuxSubgroupSize(const llvm::Function &F);
179+
175180
} // namespace utils
176181
} // namespace compiler
177182

modules/compiler/utils/include/compiler/utils/vectorization_factor.h

+16
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,26 @@ class VectorizationFactor {
7676
/// factor represents.
7777
unsigned getKnownMin() const { return KnownMin; }
7878

79+
VectorizationFactor operator*(unsigned other) const {
80+
auto res = *this;
81+
res.KnownMin *= other;
82+
return res;
83+
}
84+
7985
bool operator==(const VectorizationFactor &other) const {
8086
return KnownMin == other.KnownMin && IsScalable == other.IsScalable;
8187
}
8288

89+
bool operator!=(const VectorizationFactor &other) const {
90+
return !operator==(other);
91+
}
92+
93+
bool operator==(unsigned other) const {
94+
return !IsScalable && KnownMin == other;
95+
}
96+
97+
bool operator!=(unsigned other) const { return !operator==(other); }
98+
8399
private:
84100
unsigned KnownMin = 1;
85101
bool IsScalable = false;

modules/compiler/utils/source/attributes.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -198,5 +198,11 @@ bool hasDegenerateSubgroups(const Function &F) {
198198
return Attr.isValid();
199199
}
200200

201+
unsigned getMuxSubgroupSize(const llvm::Function &) {
202+
// FIXME: The mux sub-group size is currently assumed to be 1 for all
203+
// functions, kernels, and targets. This helper function is just to avoid
204+
// hard-coding the constant 1 in places that will eventually need to be updated.
205+
return 1;
206+
}
201207
} // namespace utils
202208
} // namespace compiler

modules/compiler/utils/source/degenerate_sub_group_pass.cpp

+24-11
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,23 @@ std::optional<compiler::utils::Builtin> isSubGroupFunction(
6868
}
6969

7070
/// @return The builtin ID of the work-group equivalent of the given builtin.
71-
Function *lookupWGBuiltin(const compiler::utils::Builtin &SGBuiltin,
72-
compiler::utils::BuiltinInfo &BI, Module &M) {
73-
compiler::utils::BuiltinID WGBuiltinID = compiler::utils::eBuiltinInvalid;
71+
compiler::utils::BuiltinID lookupWGBuiltinID(
72+
const compiler::utils::Builtin &SGBuiltin,
73+
compiler::utils::BuiltinInfo &BI) {
7474
if (SGBuiltin.ID == compiler::utils::eMuxBuiltinSubGroupBarrier) {
75-
WGBuiltinID = compiler::utils::eMuxBuiltinWorkGroupBarrier;
76-
} else {
77-
auto SGCollective = BI.isMuxGroupCollective(SGBuiltin.ID);
78-
assert(SGCollective.has_value() && "Not a sub-group builtin");
79-
auto WGCollective = *SGCollective;
80-
WGCollective.Scope = compiler::utils::GroupCollective::ScopeKind::WorkGroup;
81-
WGBuiltinID = BI.getMuxGroupCollective(WGCollective);
75+
return compiler::utils::eMuxBuiltinWorkGroupBarrier;
8276
}
77+
auto SGCollective = BI.isMuxGroupCollective(SGBuiltin.ID);
78+
assert(SGCollective.has_value() && "Not a sub-group builtin");
79+
auto WGCollective = *SGCollective;
80+
WGCollective.Scope = compiler::utils::GroupCollective::ScopeKind::WorkGroup;
81+
return BI.getMuxGroupCollective(WGCollective);
82+
}
83+
84+
/// @return The work-group equivalent of the given builtin.
85+
Function *lookupWGBuiltin(const compiler::utils::Builtin &SGBuiltin,
86+
compiler::utils::BuiltinInfo &BI, Module &M) {
87+
compiler::utils::BuiltinID WGBuiltinID = lookupWGBuiltinID(SGBuiltin, BI);
8388
// Not all sub-group builtins have a work-group equivalent.
8489
if (WGBuiltinID == compiler::utils::eBuiltinInvalid) {
8590
return nullptr;
@@ -268,6 +273,13 @@ PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run(
268273
if (isKernelEntryPt(F)) {
269274
kernels.push_back(&F);
270275

276+
if (compiler::utils::getReqdSubgroupSize(F)) {
277+
// If there's a user-specified required sub-group size, we don't need to
278+
// clone this kernel. If vectorization fails to produce the right
279+
// sub-group size, we'll fail compilation.
280+
continue;
281+
}
282+
271283
auto const local_sizes = compiler::utils::getLocalSizeMetadata(F);
272284
if (!local_sizes) {
273285
// If we don't know the local size at compile time, we can't guarantee
@@ -337,7 +349,8 @@ PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run(
337349
if (usesSubgroups.insert(&F).second) {
338350
worklist.push_back(&F);
339351
}
340-
if (SGBuiltin && !lookupWGBuiltin(*SGBuiltin, BI, M)) {
352+
if (SGBuiltin && lookupWGBuiltinID(*SGBuiltin, BI) ==
353+
compiler::utils::eBuiltinInvalid) {
341354
poisonList.insert(&F);
342355
}
343356
}

modules/compiler/utils/source/verify_reqd_sub_group_size_pass.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <compiler/utils/attributes.h>
1818
#include <compiler/utils/device_info.h>
1919
#include <compiler/utils/metadata.h>
20+
#include <compiler/utils/vectorization_factor.h>
2021
#include <compiler/utils/verify_reqd_sub_group_size_pass.h>
2122
#include <llvm/IR/DiagnosticInfo.h>
2223
#include <llvm/IR/DiagnosticPrinter.h>
@@ -94,14 +95,17 @@ PreservedAnalyses VerifyReqdSubGroupSizeSatisfiedPass::run(
9495
if (!ReqdSGSize) {
9596
continue;
9697
}
98+
99+
auto CurrSGSize = VectorizationFactor::getFixedWidth(
100+
compiler::utils::getMuxSubgroupSize(F));
97101
if (auto VeczInfo = parseVeczToOrigFnLinkMetadata(F)) {
98-
if (!VeczInfo->second.vf.isScalable() &&
99-
VeczInfo->second.vf.getKnownMin() == *ReqdSGSize) {
100-
continue;
101-
}
102+
CurrSGSize = VeczInfo->second.vf * CurrSGSize.getKnownMin();
103+
}
104+
105+
if (CurrSGSize != ReqdSGSize) {
106+
M.getContext().diagnose(DiagnosticInfoReqdSGSize(
107+
F, *ReqdSGSize, DiagnosticInfoReqdSGSize::DK_FailedReqdSGSize));
102108
}
103-
M.getContext().diagnose(DiagnosticInfoReqdSGSize(
104-
F, *ReqdSGSize, DiagnosticInfoReqdSGSize::DK_FailedReqdSGSize));
105109
}
106110

107111
return PreservedAnalyses::all();

0 commit comments

Comments
 (0)