Skip to content

Commit

Permalink
[vecz] Packetize the sub-group size on top of mux sub-groups
Browse files Browse the repository at this point in the history
The total vectorized sub-group size is the mux sub-group reduction sum
of each vectorized kernel's group size (i.e., its vector length when one
is in use, or otherwise its vectorization factor).
  • Loading branch information
frasercrmck committed Aug 30, 2023
1 parent 619c065 commit 1db757b
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 16 deletions.
39 changes: 27 additions & 12 deletions modules/compiler/vecz/source/transform/packetizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,23 +694,38 @@ bool Packetizer::Impl::packetize() {
continue;
}

auto *const Callee = CI->getCalledFunction();
if (Callee && Ctx.builtins().analyzeBuiltin(*Callee).ID ==
if (auto *const Callee = CI->getCalledFunction();
Callee && Ctx.builtins().analyzeBuiltin(*Callee).ID ==
compiler::utils::eMuxBuiltinGetSubGroupSize) {
auto *const replacement = [this](CallInst *CI) -> Value * {
// The vectorized sub-group size is the mux sub-group reduction sum
// of all of the vectorized sub-group sizes:
// | mux 0 | mux 1 |
// | < a,b,c,d > | < e,f,g > (vl=3) |
// The total sub-group size above is 4 + 3 => 7.
// Note that this expects that the mux sub-group consists entirely of
// equivalently vectorized kernels.
Value *VecgroupSize;
IRBuilder<> B(CI);
auto *const I32Ty = B.getInt32Ty();
if (VL) {
return VL;
}

auto *const I32Ty = Type::getInt32Ty(F.getContext());
auto *const VFVal =
ConstantInt::get(I32Ty, SimdWidth.getKnownMinValue());
if (!SimdWidth.isScalable()) {
return VFVal;
VecgroupSize = VL;
} else {
IRBuilder<> B(CI);
return B.CreateVScale(VFVal);
auto *const VFVal = B.getInt32(SimdWidth.getKnownMinValue());
if (!SimdWidth.isScalable()) {
VecgroupSize = VFVal;
} else {
VecgroupSize = B.CreateVScale(VFVal);
}
}
assert(VecgroupSize && "Could not determine vector group size");

auto *ReduceFn = Ctx.builtins().getOrDeclareMuxBuiltin(
compiler::utils::eMuxBuiltinSubgroupReduceAdd, *F.getParent(),
{I32Ty});
assert(ReduceFn && "Could not get reduction builtin");

return B.CreateCall(ReduceFn, VecgroupSize, "subgroup.size");
}(CI);
CI->replaceAllUsesWith(replacement);
IC.deleteInstructionLater(CI);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ define spir_kernel void @get_sub_group_size(i32 addrspace(1)* %in, i32 addrspace
; CHECK-LABEL: define spir_kernel void @__vecz_nxv4_get_sub_group_size(
; CHECK: [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
; CHECK: [[W:%.*]] = shl i32 [[VSCALE]], 2
; CHECK: store i32 [[W]], ptr addrspace(1) {{.*}}
; CHECK: [[RED:%.*]] = call i32 @__mux_sub_group_reduce_add_i32(i32 [[W]])
; CHECK: store i32 [[RED]], ptr addrspace(1) {{.*}}
}

define spir_kernel void @get_sub_group_local_id(i32 addrspace(1)* %in, i32 addrspace(1)* %out) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ define spir_kernel void @get_sub_group_size(i32 addrspace(1)* %in, i32 addrspace
; CHECK-F2: [[WL:%.*]] = sub {{.*}} i64 [[SZ]], [[ID]]
; CHECK-F2: [[VL0:%.*]] = call i64 @llvm.umin.i64(i64 [[WL]], i64 2)
; CHECK-F2: [[VL1:%.*]] = trunc i64 [[VL0]] to i32
; CHECK-F2: store i32 [[VL1]], ptr addrspace(1) {{.*}}
; CHECK-F2: [[RED:%.*]] = call i32 @__mux_sub_group_reduce_add_i32(i32 [[VL1]])
; CHECK-F2: store i32 [[RED]], ptr addrspace(1) {{.*}}

; CHECK-S4-LABEL: define spir_kernel void @__vecz_nxv4_vp_get_sub_group_size(
; CHECK-S4: [[ID:%.*]] = call i64 @__mux_get_local_id(i32 0)
Expand All @@ -51,4 +52,5 @@ define spir_kernel void @get_sub_group_size(i32 addrspace(1)* %in, i32 addrspace
; CHECK-S4: [[VF1:%.*]] = shl i64 [[VF0]], 2
; CHECK-S4: [[VL0:%.*]] = call i64 @llvm.umin.i64(i64 [[WL]], i64 [[VF1]])
; CHECK-S4: [[VL1:%.*]] = trunc i64 [[VL0]] to i32
; CHECK-S4: store i32 [[VL1]], ptr addrspace(1) {{.*}}
; CHECK-S4: [[RED:%.*]] = call i32 @__mux_sub_group_reduce_add_i32(i32 [[VL1]])
; CHECK-S4: store i32 [[RED]], ptr addrspace(1) {{.*}}
3 changes: 2 additions & 1 deletion modules/compiler/vecz/test/lit/llvm/subgroup_builtins.ll
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ define spir_kernel void @get_sub_group_size(i32 addrspace(1)* %in, i32 addrspace
store i32 %call2, i32 addrspace(1)* %arrayidx, align 4
ret void
; CHECK-LABEL: define spir_kernel void @__vecz_v4_get_sub_group_size(
; CHECK: store i32 4, ptr addrspace(1) {{.*}}
; CHECK: [[RED:%.*]] = call i32 @__mux_sub_group_reduce_add_i32(i32 4)
; CHECK: store i32 [[RED]], ptr addrspace(1) {{.*}}
}

define spir_kernel void @get_sub_group_local_id(i32 addrspace(1)* %in, i32 addrspace(1)* %out) {
Expand Down

0 comments on commit 1db757b

Please sign in to comment.