Skip to content

Commit

Permalink
Merge pull request #145 from frasercrmck/replace-wgc-scans
Browse files Browse the repository at this point in the history
[compiler] Add work-group scan support in vecz/work-item-loops
  • Loading branch information
frasercrmck authored Oct 10, 2023
2 parents e1841a5 + f758c11 commit 7372f79
Show file tree
Hide file tree
Showing 15 changed files with 711 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ llvm::ModulePassManager RefSiG1PassMachinery::getLateTargetPasses() {
auto env_var_opts =
processOptimizationOptions(env_debug_prefix, /* vecz_mode*/ {});

// We don't run the WorkItemLoopsPass; we need an implementation of
// work-group collective operations.
tuner.replace_work_group_collectives = true;

PM.addPass(compiler::utils::TransferKernelMetadataPass());

if (env_debug_prefix) {
Expand Down Expand Up @@ -130,8 +134,6 @@ llvm::ModulePassManager RefSiG1PassMachinery::getLateTargetPasses() {
// addLateBuiltinsPasses, which isn't ideal.
PM.addPass(compiler::utils::DefineMuxDmaPass());

// We don't run the WorkItemLoopsPass; make sure that's taken into account.
tuner.handling_work_item_loops = false;
addPreVeczPasses(PM, tuner);

addLateBuiltinsPasses(PM, tuner);
Expand Down
5 changes: 3 additions & 2 deletions modules/compiler/source/base/include/base/pass_pipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ struct BasePassPipelineTuner {
/// @brief Whether or not to generate code for degenerate sub groups.
bool degenerate_sub_groups = false;

/// @brief Whether or not the WorkItemLoopsPass is going to be run.
bool handling_work_item_loops = true;
/// @brief Whether or not to replace work-group collectives early before
/// vectorization.
bool replace_work_group_collectives = false;

/// @brief The desired target calling convention, used to configure the
/// FixupCallingConvention pass.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,6 @@ Expected<bool> parseReplaceMuxMathDeclsPassOptions(StringRef Params) {
"ReplaceMuxMathDeclsPass");
}

Expected<bool> parseReplaceWGCPassOptions(StringRef Params) {
return compiler::utils::parseSinglePassOption(Params, "scans-only",
"ReplaceWGCPass");
}

// Lookup table for calling convention enums
std::unordered_map<std::string, CallingConv::ID> CallConvMap = {
{"C", CallingConv::C},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ MODULE_PASS("rename-builtins", compiler::utils::RenameBuiltinsPass())
MODULE_PASS("replace-atomic-funcs", compiler::utils::ReplaceAtomicFuncsPass())
MODULE_PASS("replace-c11-atomic-funcs",
compiler::utils::ReplaceC11AtomicFuncsPass())

MODULE_PASS("replace-module-scope-vars",
compiler::utils::ReplaceLocalModuleScopeVariablesPass())
MODULE_PASS("replace-wgc", compiler::utils::ReplaceWGCPass())

MODULE_PASS("builtin-simplify", compiler::BuiltinSimplificationPass())
MODULE_PASS("image-arg-subst", compiler::ImageArgumentSubstitutionPass())
Expand All @@ -70,14 +70,6 @@ MODULE_PASS("print<vecz-pass-opts>", vecz::VeczPassOptionsPrinterPass(dbgs()))
#define MODULE_PASS_WITH_PARAMS(NAME, CLASS, CREATE_PASS, PARSER, PARAMS)
#endif

MODULE_PASS_WITH_PARAMS(
"replace-wgc", "compiler::utils::ReplaceWGCPass",
[](bool ScansOnly) {
return compiler::utils::ReplaceWGCPass(ScansOnly);
},
parseReplaceWGCPassOptions,
"scans-only")

MODULE_PASS_WITH_PARAMS(
"add-kernel-wrapper", "compiler::utils::AddKernelWrapperPass",
[](compiler::utils::AddKernelWrapperPassOptions Options) {
Expand Down
11 changes: 5 additions & 6 deletions modules/compiler/source/base/source/pass_pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@ void addPreVeczPasses(ModulePassManager &PM,
PM.addPass(compiler::utils::DegenerateSubGroupPass());
}

// We need to use the software implementation of the work-group collective
// builtins. Because ReplaceWGCPass may introduce barrier calls it needs to be
// run before PrepareBarriersPass. When using the WorkItemLoopsPass, we can
// run the Replace WGC pass in Scans Only mode, since the WorkItemLoopsPass
// has its own implementations of reductions and broadcasts.
PM.addPass(compiler::utils::ReplaceWGCPass(tuner.handling_work_item_loops));
if (tuner.replace_work_group_collectives) {
// Because ReplaceWGCPass may introduce barrier calls it needs to be run
// before PrepareBarriersPass.
PM.addPass(compiler::utils::ReplaceWGCPass());
}

// We have to inline all functions containing barriers before running vecz,
// because the barriers in both the scalar and vector kernels need to be
Expand Down
64 changes: 28 additions & 36 deletions modules/compiler/test/lit/passes/replace-wgc.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,39 @@
;
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

; RUN: muxc --passes replace-wgc,verify < %s \
; RUN: | FileCheck %s --check-prefixes CHECK,CHECK-ALL
; RUN: muxc --passes "replace-wgc<scans-only>,verify" < %s \
; RUN: | FileCheck %s --check-prefixes CHECK,CHECK-SCANS-ONLY
; RUN: muxc --passes replace-wgc,verify < %s | FileCheck %s

; Check that the replace-wgc pass correctly defines the work-group collective
; functions (reductions, scans and broadcasts).

target triple = "spir64-unknown-unknown"
target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128"

; CHECK-ALL: @__mux_work_group_reduce_smin_i32.accumulator = internal addrspace(3) global i32 undef
; CHECK: @__mux_work_group_reduce_smin_i32.accumulator = internal addrspace(3) global i32 undef
; CHECK: @__mux_work_group_scan_inclusive_umax_i32.accumulator = internal addrspace(3) global i32 undef
; CHECK: @__mux_work_group_scan_exclusive_fadd_f32.accumulator = internal addrspace(3) global float undef
; CHECK-ALL: @__mux_work_group_broadcast_i32.accumulator = internal addrspace(3) global i32 undef
; CHECK: @__mux_work_group_broadcast_i32.accumulator = internal addrspace(3) global i32 undef

; If this isn't a scan we shouldn't have defined it
; CHECK-SCANS-ONLY: declare spir_func i32 @__mux_work_group_reduce_smin_i32(i32, i32)
declare spir_func i32 @__mux_work_group_reduce_smin_i32(i32 %id, i32 %x)
; CHECK-ALL: define spir_func i32 @__mux_work_group_reduce_smin_i32(i32 %id, i32 [[PARAM:%.*]])
; CHECK-ALL-LABEL: entry:
; CHECK-ALL: %[[SUBGROUP:.+]] = call i32 @__mux_sub_group_reduce_smin_i32(i32 %{{.+}})
; CHECK-ALL: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272) [[SCHEDULE_ONCE:#[0-9]+]]
; CHECK-ALL: store i32 2147483647, ptr addrspace(3) @__mux_work_group_reduce_smin_i32.accumulator
; CHECK-ALL: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272)
; CHECK-ALL: %[[CURRVAL:.+]] = load i32, ptr addrspace(3) @__mux_work_group_reduce_smin_i32.accumulator
; CHECK-ALL: %[[ACCUM:.*]] = call i32 @llvm.smin.i32(i32 %[[CURRVAL]], i32 %[[SUBGROUP]])
; CHECK-ALL: store i32 %[[ACCUM]], ptr addrspace(3) @__mux_work_group_reduce_smin_i32.accumulator
; CHECK-ALL: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272)
; CHECK-ALL: %[[RESULT:.*]] = load i32, ptr addrspace(3) @__mux_work_group_reduce_smin_i32.accumulator
; CHECK-ALL: ret i32 %[[RESULT]]
; CHECK: define spir_func i32 @__mux_work_group_reduce_smin_i32(i32 %id, i32 [[PARAM:%.*]])
; CHECK-LABEL: entry:
; CHECK: %[[SUBGROUP:.+]] = call i32 @__mux_sub_group_reduce_smin_i32(i32 %{{.+}})
; CHECK: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272) [[SCHEDULE_ONCE:#[0-9]+]]
; CHECK: store i32 2147483647, ptr addrspace(3) @__mux_work_group_reduce_smin_i32.accumulator
; CHECK: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272)
; CHECK: %[[CURRVAL:.+]] = load i32, ptr addrspace(3) @__mux_work_group_reduce_smin_i32.accumulator
; CHECK: %[[ACCUM:.*]] = call i32 @llvm.smin.i32(i32 %[[CURRVAL]], i32 %[[SUBGROUP]])
; CHECK: store i32 %[[ACCUM]], ptr addrspace(3) @__mux_work_group_reduce_smin_i32.accumulator
; CHECK: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272)
; CHECK: %[[RESULT:.*]] = load i32, ptr addrspace(3) @__mux_work_group_reduce_smin_i32.accumulator
; CHECK: ret i32 %[[RESULT]]


declare spir_func i32 @__mux_work_group_scan_inclusive_umax_i32(i32 %id, i32 %x)
; CHECK: define spir_func i32 @__mux_work_group_scan_inclusive_umax_i32(i32 %id, i32 [[PARAM:%.*]])
; CHECK-LABEL: entry:
; This is just to ensure SCHEDULE_ONCE is defined on all paths...
; CHECK-ALL: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272) [[SCHEDULE_ONCE]]
; CHECK-SCANS-ONLY: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272) [[SCHEDULE_ONCE:#[0-9]+]]
; CHECK: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272) [[SCHEDULE_ONCE]]
; CHECK: store i32 0, ptr addrspace(3) @__mux_work_group_scan_inclusive_umax_i32.accumulator
; CHECK: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272) [[SCHEDULE_LINEAR:#[0-9]+]]
; CHECK: %[[CURRVAL:.+]] = load i32, ptr addrspace(3) @__mux_work_group_scan_inclusive_umax_i32.accumulator
Expand Down Expand Up @@ -98,21 +92,19 @@ declare spir_func float @__mux_work_group_scan_exclusive_fadd_f32(i32 %id, float
; CHECK: %[[RESULT:.+]] = select i1 %[[CMPXYZ]], float 0.000000e+00, float %[[WGSCAN]]
; CHECK: ret float %[[RESULT]]


; CHECK-SCANS-ONLY: declare spir_func i32 @__mux_work_group_broadcast_i32(i32, i32, i64, i64, i64)
declare spir_func i32 @__mux_work_group_broadcast_i32(i32 %barrier_id, i32 %x, i64 %idx, i64 %idy, i64 %idz)
; CHECK-ALL: define spir_func i32 @__mux_work_group_broadcast_i32(i32 %barrier_id, i32 [[PARAM:%.*]], i64 {{%.*}}, i64 {{%.*}}, i64 {{%.*}})
; CHECK-ALL-LABEL: entry:
; CHECK-ALL: call i64 @__mux_get_local_id(i32 0)

; CHECK-ALL-LABEL: broadcast:
; CHECK-ALL: store i32 [[PARAM]], ptr addrspace(3) @__mux_work_group_broadcast_i32.accumulator

; CHECK-ALL-LABEL: exit:
; CHECK-ALL: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272)
; CHECK-ALL: [[RESULT:%.*]] = load i32, ptr addrspace(3) @__mux_work_group_broadcast_i32.accumulator
; CHECK-ALL: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272)
; CHECK-ALL: ret i32 [[RESULT]]
; CHECK: define spir_func i32 @__mux_work_group_broadcast_i32(i32 %barrier_id, i32 [[PARAM:%.*]], i64 {{%.*}}, i64 {{%.*}}, i64 {{%.*}})
; CHECK-LABEL: entry:
; CHECK: call i64 @__mux_get_local_id(i32 0)

; CHECK-LABEL: broadcast:
; CHECK: store i32 [[PARAM]], ptr addrspace(3) @__mux_work_group_broadcast_i32.accumulator

; CHECK-LABEL: exit:
; CHECK: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272)
; CHECK: [[RESULT:%.*]] = load i32, ptr addrspace(3) @__mux_work_group_broadcast_i32.accumulator
; CHECK: call void @__mux_work_group_barrier(i32 0, i32 2, i32 272)
; CHECK: ret i32 [[RESULT]]


declare spir_func half @__mux_work_group_scan_exclusive_fadd_f16(i32 %id, half %x)
Expand Down
57 changes: 57 additions & 0 deletions modules/compiler/test/lit/passes/work-item-loops-scan-1.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
; Copyright (C) Codeplay Software Limited
;
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
; Exceptions; you may not use this file except in compliance with the License.
; You may obtain a copy of the License at
;
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
;
; Unless required by applicable law or agreed to in writing, software
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
; License for the specific language governing permissions and limitations
; under the License.
;
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

; RUN: muxc --passes work-item-loops,verify < %s | FileCheck %s

target triple = "spir64-unknown-unknown"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"

; Kernel under test (attribute group #0 marks it "mux-kernel"="entry-point").
; Each invocation loads one i32 from %in at the index returned by
; __mux_get_local_id(0), passes it through a work-group inclusive multiply
; scan, and stores the scan result to %out at the same index.
define void @foo(ptr %in, ptr %out) #0 {
entry:
; Index used for both the input load and the output store.
%id = call i64 @__mux_get_local_id(i32 0)
%inaddr = getelementptr inbounds i32, ptr %in, i64 %id
%val = load i32, ptr %inaddr
; The work-group scan call that the work-item-loops pass has to handle; the
; FileCheck directives below expect it to become a running 'mul' accumulator
; carried through the X/Y/Z work-item loops.
%scan = tail call i32 @__mux_work_group_scan_inclusive_mul_i32(i32 0, i32 %val) #4
%outaddr = getelementptr inbounds i32, ptr %out, i64 %id
store i32 %scan, ptr %outaddr
ret void
}

; CHECK: define void @foo.mux-barrier-wrapper(ptr %in, ptr %out)

; CHECK-LABEL: loopIR7:
; CHECK: [[PHIZ:%.*]] = phi i32 [ 1, %sw.bb2 ], [ [[ACC:%.*]], %exitIR11 ]

; CHECK-LABEL: loopIR8:
; CHECK: [[PHIY:%.*]] = phi i32 [ [[PHIZ]], %loopIR7 ], [ [[ACC]], %exitIR10 ]

; CHECK-LABEL: loopIR9:
; CHECK: [[PHIX:%.*]] = phi i32 [ [[PHIY]], %loopIR8 ], [ [[ACC]], %loopIR9 ]
; CHECK: [[VAL:%.*]] = load i32, ptr %live_gep_val, align 4
; CHECK: [[ACC]] = mul i32 [[PHIX]], [[VAL]]
; CHECK: call i32 @foo.mux-barrier-region.1(ptr %in, ptr %out, i32 [[ACC]],

declare i32 @__mux_work_group_scan_inclusive_mul_i32(i32, i32) #1

declare i64 @__mux_get_local_id(i32) #2

attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" }
attributes #1 = { alwaysinline convergent norecurse nounwind }
attributes #2 = { alwaysinline norecurse nounwind readonly }

!opencl.ocl.version = !{!0}

!0 = !{i32 3, i32 0}
110 changes: 110 additions & 0 deletions modules/compiler/test/lit/passes/work-item-loops-scan-2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
; Copyright (C) Codeplay Software Limited
;
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
; Exceptions; you may not use this file except in compliance with the License.
; You may obtain a copy of the License at
;
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
;
; Unless required by applicable law or agreed to in writing, software
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
; License for the specific language governing permissions and limitations
; under the License.
;
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

; RUN: muxc --passes work-item-loops,verify < %s | FileCheck %s

target triple = "spir64-unknown-unknown"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"

; Scalar form of the kernel. Its !codeplay_ca_vecz.base metadata (!1) refers to
; @__vecz_nxv4_foo, pairing it with the vector form defined further down.
; Each invocation loads one i32 from %in at the index returned by
; __mux_get_local_id(0), passes it through a work-group inclusive multiply
; scan, and stores the scan result to %out at the same index.
; The FileCheck directives below expect this scalar form to supply the body of
; the 'tail' X loop (@foo.mux-barrier-region.2).
define void @foo(ptr %in, ptr %out) #0 !codeplay_ca_vecz.base !1 {
entry:
; Index used for both the input load and the output store.
%id = call i64 @__mux_get_local_id(i32 0)
%inaddr = getelementptr inbounds i32, ptr %in, i64 %id
%val = load i32, ptr %inaddr, align 4
; Inclusive scan: the expected output passes the 'current' accumulator value
; to the barrier region for this form.
%scan = tail call i32 @__mux_work_group_scan_inclusive_mul_i32(i32 0, i32 %val)
%outaddr = getelementptr inbounds i32, ptr %out, i64 %id
store i32 %scan, ptr %outaddr, align 4
ret void
}

declare i32 @__mux_work_group_scan_inclusive_mul_i32(i32, i32) #1

declare i64 @__mux_get_local_id(i32) #2

; Scalable-vector form of @foo. Its !codeplay_ca_vecz.derived metadata (!3)
; points back at @foo. It processes a <vscale x 4 x i32> per invocation and
; splits the original work-group inclusive scan into three pieces:
;   1. a per-vector inclusive multiply scan (the __vecz_b_sub_group_scan helper),
;   2. a work-group *exclusive* multiply scan over each vector's total product,
;   3. a final multiply combining the two.
; The FileCheck directives below expect this form to supply the body of the
; 'main' X loop (@__vecz_nxv4_foo.mux-barrier-region.1).
define void @__vecz_nxv4_foo(ptr %in, ptr %out) #3 !codeplay_ca_vecz.derived !3 {
entry:
; Index used for both the vector load and the vector store.
%id = call i64 @__mux_get_local_id(i32 0)
%inaddr = getelementptr inbounds i32, ptr %in, i64 %id
%0 = load <vscale x 4 x i32>, ptr %inaddr, align 4
; Piece 1: inclusive multiply scan of the loaded vector. The helper is only
; declared in this file; its body is not needed by the test.
%1 = call <vscale x 4 x i32> @__vecz_b_sub_group_scan_inclusive_mul_u5nxv4j(<vscale x 4 x i32> %0) #6
; Multiply all lanes of the loaded vector together into a single i32.
%2 = call i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32> %0)
; Piece 2: work-group exclusive scan of that per-vector product. This is the
; call the work-item-loops pass has to handle in this function.
%3 = call i32 @__mux_work_group_scan_exclusive_mul_i32(i32 0, i32 %2)
; Broadcast the scalar scan result into every lane.
%.splatinsert = insertelement <vscale x 4 x i32> poison, i32 %3, i64 0
%.splat = shufflevector <vscale x 4 x i32> %.splatinsert, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; Piece 3: combine the per-vector scan with the work-group prefix.
%4 = mul <vscale x 4 x i32> %1, %.splat
%outaddr = getelementptr inbounds i32, ptr %out, i64 %id
store <vscale x 4 x i32> %4, ptr %outaddr, align 4
ret void
}

; CHECK: define void @__vecz_nxv4_foo.mux-barrier-wrapper(ptr %in, ptr %out)

; Check a linear loop structure, looping over all Zs and Ys, and in the inner
; loop doing the main and tail X items in sequence.
; CHECK-LABEL: loopIR11:
; CHECK: [[PHIZ:%.*]] = phi i32 [ 1, %sw.bb2 ], [ [[TAIL_MERGE:%.*]], %exitIR15 ]

; CHECK-LABEL: loopIR12:
; CHECK: [[PHIY:%.*]] = phi i32 [ [[PHIZ]], %loopIR11 ], [ [[TAIL_MERGE]], %ca_work_item_x_tail_exit ]

; Main loop
; CHECK-LABEL: loopIR13:
; CHECK: [[MPHIX:%.*]] = phi i32 [ [[PHIY]], %ca_work_item_x_main_preheader ], [ [[MACC:%.*]], %loopIR13 ]
; CHECK: [[MVAL:%.*]] = load i32, ptr %live_gep_, align 4
; CHECK: [[MACC]] = mul i32 [[MPHIX]], [[MVAL]]
; This is an exclusive scan, so pass the 'previous' value to the sub-kernel.
; CHECK: call i32 @__vecz_nxv4_foo.mux-barrier-region.1(ptr %in, ptr %out, i32 [[MPHIX]],

; CHECK-LABEL: ca_work_item_x_main_exit:
; CHECK: [[MERGE:%.*]] = phi i32 [ [[PHIY]], %loopIR12 ], [ [[MACC]], %loopIR13 ]

; Tail loop
; CHECK-LABEL: loopIR14:
; CHECK: [[TPHIX:%.*]] = phi i32 [ [[MERGE]], %ca_work_item_x_tail_preheader ],
; CHECK-SAME: [ [[TACC:%.*]], %loopIR14 ]
; CHECK: [[TVAL:%.*]] = load i32, ptr %live_gep_val, align 4
; CHECK: [[TACC]] = mul i32 [[TPHIX]], [[TVAL]]
; This is an inclusive scan, so pass the 'current' value to the sub-kernel.
; CHECK: call i32 @foo.mux-barrier-region.2(ptr %in, ptr %out, i32 [[TACC]],

; Note - vecz normally generates the body for this helper, but it's irrelevant
; for this test
declare <vscale x 4 x i32> @__vecz_b_sub_group_scan_inclusive_mul_u5nxv4j(<vscale x 4 x i32> %0)

declare i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32>) #4

declare i32 @__mux_work_group_scan_exclusive_mul_i32(i32, i32) #1

declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32() #4

declare i32 @llvm.vscale.i32() #4

declare <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr>, i32 immarg, <vscale x 4 x i1>, <vscale x 4 x i32>) #5

attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" }
attributes #1 = { alwaysinline convergent norecurse nounwind }
attributes #2 = { alwaysinline norecurse nounwind readonly }
attributes #3 = { convergent norecurse nounwind "mux-base-fn-name"="__vecz_nxv4_foo" "mux-kernel"="entry-point" }
attributes #4 = { nocallback nofree nosync nounwind willreturn readnone }
attributes #5 = { nocallback nofree nosync nounwind willreturn readonly }
attributes #6 = { nounwind }

!opencl.ocl.version = !{!0}

!0 = !{i32 3, i32 0}
!1 = !{!2, ptr @__vecz_nxv4_foo}
!2 = !{i32 4, i32 1, i32 0, i32 0}
!3 = !{!2, ptr @foo}
Loading

0 comments on commit 7372f79

Please sign in to comment.