Merge pull request #113 from frasercrmck/vecz-sub-groups
[vecz] Vectorize sub-groups on top of mux sub-groups
frasercrmck authored Aug 31, 2023
2 parents b1ee1e5 + 1db757b commit 5eda28c
Showing 12 changed files with 399 additions and 153 deletions.
@@ -73,15 +73,6 @@ Value *processCallSite(CallInst *CI, bool &NeedLLVMInline,
}
}

// Vectorized uses of the subgroup local id will have been replaced with step
// vectors starting from zero. Uniform uses should be replaced with zero in
// order to maintain equivalence between the scalar/vector forms. Do this
// here due to a tight coupling between the vectorized version and these
// remaining scalar versions.
if (Builtin.ID == compiler::utils::eMuxBuiltinGetSubGroupLocalId) {
return ConstantInt::getNullValue(CI->getType());
}

return CI;
}

202 changes: 132 additions & 70 deletions modules/compiler/vecz/source/transform/packetizer.cpp
@@ -190,18 +190,18 @@ class Packetizer::Impl : public Packetizer {
///
/// @return Packetized instruction.
Value *packetizeMaskVarying(Instruction *I);
/// @brief Packetize a mask-varying subgroup reduction.
/// @brief Packetize a mask-varying subgroup/workgroup reduction.
///
/// @param[in] I Instruction to packetize.
///
/// @return Packetized instruction.
Value *packetizeSubgroupReduction(Instruction *I);
/// @brief Packetize a subgroup broadcast.
Value *packetizeGroupReduction(Instruction *I);
/// @brief Packetize a subgroup/workgroup broadcast.
///
/// @param[in] I Instruction to packetize.
///
/// @return Packetized instruction.
Value *packetizeSubgroupBroadcast(Instruction *I);
Value *packetizeGroupBroadcast(Instruction *I);
/// @brief Packetize PHI node.
///
/// @param[in] Phi PHI Node to packetize.
@@ -693,23 +693,38 @@ bool Packetizer::Impl::packetize() {
continue;
}

auto *const Callee = CI->getCalledFunction();
if (Callee && Ctx.builtins().analyzeBuiltin(*Callee).ID ==
if (auto *const Callee = CI->getCalledFunction();
Callee && Ctx.builtins().analyzeBuiltin(*Callee).ID ==
compiler::utils::eMuxBuiltinGetSubGroupSize) {
auto *const replacement = [this](CallInst *CI) -> Value * {
// The vectorized sub-group size is the mux sub-group reduction sum
// of all of the vectorized sub-group sizes:
// | mux 0 | mux 1 |
// | < a,b,c,d > | < e,f,g > (vl=3) |
// The total sub-group size above is 4 + 3 => 7.
// Note that this expects that the mux sub-group consists entirely of
// equivalently vectorized kernels.
Value *VecgroupSize;
IRBuilder<> B(CI);
auto *const I32Ty = B.getInt32Ty();
if (VL) {
return VL;
}

auto *const I32Ty = Type::getInt32Ty(F.getContext());
auto *const VFVal =
ConstantInt::get(I32Ty, SimdWidth.getKnownMinValue());
if (!SimdWidth.isScalable()) {
return VFVal;
VecgroupSize = VL;
} else {
IRBuilder<> B(CI);
return B.CreateVScale(VFVal);
auto *const VFVal = B.getInt32(SimdWidth.getKnownMinValue());
if (!SimdWidth.isScalable()) {
VecgroupSize = VFVal;
} else {
VecgroupSize = B.CreateVScale(VFVal);
}
}
assert(VecgroupSize && "Could not determine vector group size");

auto *ReduceFn = Ctx.builtins().getOrDeclareMuxBuiltin(
compiler::utils::eMuxBuiltinSubgroupReduceAdd, *F.getParent(),
{I32Ty});
assert(ReduceFn && "Could not get reduction builtin");

return B.CreateCall(ReduceFn, VecgroupSize, "subgroup.size");
}(CI);
CI->replaceAllUsesWith(replacement);
IC.deleteInstructionLater(CI);
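
For illustration only, a minimal standalone C++ sketch of the arithmetic the hunk above builds with IRBuilder: each vectorized invocation contributes its own vector-group size (VL for a predicated tail, otherwise the known minimum width, scaled by vscale when scalable), and the mux sub_group_reduce_add sums those contributions into the apparent sub-group size. The two-invocation scenario and all names below are invented for the example, not part of the patch.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// The apparent sub-group size seen by the user is the sum of the
// per-invocation vector-group sizes across the mux sub-group, mirroring the
// __mux_sub_group_reduce_add_i32 call emitted above.
int main() {
  // Hypothetical mux sub-group of two vectorized invocations: the first has a
  // full vector group of 4, the second a predicated tail of 3 (vl = 3).
  std::vector<uint32_t> vecGroupSizes = {4, 3};

  uint32_t apparentSubGroupSize =
      std::accumulate(vecGroupSizes.begin(), vecGroupSizes.end(), 0u);

  std::cout << "apparent sub-group size = " << apparentSubGroupSize << "\n";  // 7
}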
@@ -881,11 +896,11 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
return getPacketized(Ins);
}

if (auto *reduction = packetizeSubgroupReduction(Ins)) {
if (auto *reduction = packetizeGroupReduction(Ins)) {
return broadcast(reduction);
}

if (auto *brdcast = packetizeSubgroupBroadcast(Ins)) {
if (auto *brdcast = packetizeGroupBroadcast(Ins)) {
return broadcast(brdcast);
}

@@ -1074,7 +1089,7 @@ Packetizer::Result Packetizer::Impl::packetizeInstruction(Instruction *Ins) {
return Packetizer::Result(*this, Ins, nullptr);
}

Value *Packetizer::Impl::packetizeSubgroupReduction(Instruction *I) {
Value *Packetizer::Impl::packetizeGroupReduction(Instruction *I) {
auto *const CI = dyn_cast<CallInst>(I);
if (!CI || !CI->getCalledFunction()) {
return nullptr;
@@ -1122,18 +1137,19 @@ Value *Packetizer::Impl::packetizeSubgroupReduction(Instruction *I) {
Value *&val = opPackets.front();
val = sanitizeVPReductionInput(B, val, VL, Info->Recurrence);
if (!val) {
emitVeczRemarkMissed(&F, CI,
"Can not vector-predicate subgroup reduction");
emitVeczRemarkMissed(
&F, CI, "Can not vector-predicate workgroup/subgroup reduction");
return nullptr;
}
}

// According to the OpenCL Spec, we are allowed to rearrange the operation
// order of a subgroup reduction any way we like (even though floating point
// addition is not associative so might not produce exactly the same result),
// so we reduce to a single vector first, if necessary, and then do a single
// reduction to scalar. This is more efficient than doing multiple reductions
// to scalar and then BinOp'ing multiple scalars together.
// order of a workgroup/subgroup reduction any way we like (even though
// floating point addition is not associative so might not produce exactly
// the same result), so we reduce to a single vector first, if necessary, and
// then do a single reduction to scalar. This is more efficient than doing
// multiple reductions to scalar and then BinOp'ing multiple scalars
// together.
//
// Reduce to a single vector.
while ((packetWidth >>= 1)) {
@@ -1149,21 +1165,15 @@
Value *v =
createSimpleTargetReduction(B, &TTI, opPackets.front(), Info->Recurrence);

if (isWorkGroup) {
// For a work group operation, we leave the original reduction function and
// divert the subgroup reduction through it, giving us a work group
// reduction over subgroup reductions.
CI->setOperand(argIdx, v);
v = CI;
} else {
IC.deleteInstructionLater(CI);
CI->replaceAllUsesWith(v);
}
// We leave the original reduction function and divert the vectorized
// reduction through it, giving us a reduction over the full apparent
// sub-group or work-group size (vecz * mux).
CI->setOperand(argIdx, v);

return v;
return CI;
}
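
As a standalone illustration of the two-level reduction described in the comments above (reduce the packets to one vector, reduce that vector to a scalar, then feed the scalar back through the retained mux group-reduce builtin), here is a C++ sketch for the integer-add recurrence. All names and data are invented; the mux-level reduction is modelled as a plain loop over invocations.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// One packet models one vectorized lane group of a single invocation.
using Packet = std::vector<int32_t>;

// Step 1: fold all packets of one invocation into a single vector, then
// reduce that vector to a scalar (what createSimpleTargetReduction does for
// the add case).
static int32_t reduceInvocation(const std::vector<Packet> &packets) {
  Packet acc(packets.front().size(), 0);
  for (const Packet &p : packets)
    for (size_t lane = 0; lane < p.size(); ++lane) acc[lane] += p[lane];
  return std::accumulate(acc.begin(), acc.end(), 0);
}

// Step 2: the per-invocation scalar goes back through the original group
// reduction, modelled here as a sum over the mux sub-group's invocations.
int main() {
  std::vector<std::vector<Packet>> muxSubGroup = {
      {{1, 2}, {3, 4}},  // invocation 0, packetized into two packets
      {{5, 6}, {7, 8}},  // invocation 1
  };
  int32_t total = 0;
  for (const auto &invocation : muxSubGroup) total += reduceInvocation(invocation);
  std::cout << "group reduce add = " << total << "\n";  // 36
}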

Value *Packetizer::Impl::packetizeSubgroupBroadcast(Instruction *I) {
Value *Packetizer::Impl::packetizeGroupBroadcast(Instruction *I) {
auto *const CI = dyn_cast<CallInst>(I);
if (!CI || !CI->getCalledFunction()) {
return nullptr;
@@ -1200,17 +1210,15 @@ Value *Packetizer::Impl::packetizeSubgroupBroadcast(Instruction *I) {
}

auto *idx = CI->getArgOperand(argIdx + 1);
if (isWorkGroup) {
// When it's a work group broadcast, we need to sanitize the input index so
// that it stays within the range of one subgroup.
auto *const minVal =
ConstantInt::get(idx->getType(), SimdWidth.getKnownMinValue());
Value *idxFactor = minVal;
if (SimdWidth.isScalable()) {
idxFactor = B.CreateVScale(minVal);
}
idx = B.CreateURem(idx, idxFactor);
// We need to sanitize the input index so that it stays within the range of
// one vectorized group.
auto *const minVal =
ConstantInt::get(idx->getType(), SimdWidth.getKnownMinValue());
Value *idxFactor = minVal;
if (SimdWidth.isScalable()) {
idxFactor = B.CreateVScale(minVal);
}
idx = B.CreateURem(idx, idxFactor);

Value *val = nullptr;
// Optimize the constant fixed-vector case, where we can choose the exact
@@ -1235,18 +1243,12 @@
val = B.CreateExtractElement(op.getAsValue(), idx);
}

if (isWorkGroup) {
// For a work group operation, we leave the original broadcast function and
// divert the subgroup reduction through it, giving us a work group
// reduction over subgroup reductions.
CI->setOperand(argIdx, val);
val = CI;
} else {
IC.deleteInstructionLater(CI);
CI->replaceAllUsesWith(val);
}
// We leave the original broadcast function and divert the vectorized
// broadcast through it, giving us a broadcast over the full apparent
// sub-group or work-group size (vecz * mux).
CI->setOperand(argIdx, val);

return val;
return CI;
}
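
To make the index handling above concrete, here is a hedged C++ sketch under the assumption that the apparent sub-group local ID decomposes as mux_local_id * VF + lane: the URem keeps the lane within one vector group, while selecting the source invocation is left to the retained mux broadcast builtin. The fixed VF, data, and names are illustrative only.

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of a broadcast over the apparent (vecz * mux) sub-group.
// Assumption: apparent local ID = muxLocalId * VF + lane, so the source lane
// within a vector group is idx % VF and the source invocation is idx / VF.
int main() {
  constexpr uint32_t VF = 4;  // vectorization factor (fixed width here)

  // Two vectorized invocations forming one mux sub-group.
  std::vector<std::vector<int32_t>> values = {
      {10, 11, 12, 13},  // mux invocation 0
      {20, 21, 22, 23},  // mux invocation 1
  };

  uint32_t idx = 6;                   // apparent sub-group index to broadcast
  uint32_t lane = idx % VF;           // the URem emitted above
  uint32_t srcInvocation = idx / VF;  // handled by the retained mux builtin

  std::cout << "broadcast value = " << values[srcInvocation][lane] << "\n";  // 22
}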

Value *Packetizer::Impl::packetizeMaskVarying(Instruction *I) {
@@ -1699,9 +1701,62 @@ ValuePacket Packetizer::Impl::packetizeSubgroupScan(

IRBuilder<> B(CI);

auto *c = B.CreateCall(SubgroupFn, Ops);
auto *VectorScan = B.CreateCall(SubgroupFn, Ops);

// We've currently got a scan over each vector group, but the full sub-group
// is further multiplied by the mux sub-group size. For example, we may have
// a vectorization factor sized group of 4 and a mux sub-group size of 2.
// Together the full sub-group size to the user is 4*2 = 8.
// In terms of invocations, we've essentially currently got:
// <a0, a0+a1, a0+a1+a2, a0+a1+a2+a3> (invocation 0)
// <a4, a4+a5, a4+a5+a6, a4+a5+a6+a7> (invocation 1)
// These two iterations need to be further scanned over the mux sub-group
// size. We do this by adding the identity to the first invocation, the
// result of the scan over the first invocation to the second, etc. This is
// an exclusive scan over the *reduction* of the input vector:
// <a0, a1, a2, a3> (invocation 0)
// <a4, a5, a6, a7> (invocation 1)
// -> reduction
// (a0+a1+a2+a3) (invocation 0)
// (a4+a5+a6+a7) (invocation 1)
// -> exclusive mux sub-group scan
// I (invocation 0)
// (a0+a1+a2+a3) (invocation 1)
// -> adding that to the result of the vector scan:
// <I+a0, I+a0+a1, I+a0+a1+a2, I+a0+a1+a2+a3> (invocation 0)
// <(a0+a1+a2+a3)+a4, (a0+a1+a2+a3)+a4+a5, (invocation 1)
// (a0+a1+a2+a3)+a4+a5+a6, (a0+a1+a2+a3)+a4+a5+a6+a7>
// When viewed as a full 8-element vector, this is our final scan.
// Thus we essentially keep the original mux sub-group scan, but change it to
// be an exclusive one.
auto *Reduction = Ops.front();
if (VL) {
Reduction = sanitizeVPReductionInput(B, Reduction, VL, Scan.Recurrence);
if (!Reduction) {
return results;
}
}
Reduction = createSimpleTargetReduction(B, &TTI, Reduction, Scan.Recurrence);

// Now we defer to an *exclusive* scan over the mux sub-group.
auto ExclScan = Scan;
ExclScan.Op = compiler::utils::GroupCollective::OpKind::ScanExclusive;

auto ExclScanID = Ctx.builtins().getMuxGroupCollective(ExclScan);
assert(ExclScanID != compiler::utils::eBuiltinInvalid);

auto *const ExclScanFn = Ctx.builtins().getOrDeclareMuxBuiltin(
ExclScanID, *F.getParent(), {CI->getType()});
assert(ExclScanFn);

auto *const ExclScanCI = B.CreateCall(ExclScanFn, {Reduction});

results.push_back(c);
Value *const Splat = B.CreateVectorSplat(SimdWidth, ExclScanCI);

auto *const Result = multi_llvm::createBinOpForRecurKind(B, VectorScan, Splat,
Scan.Recurrence);

results.push_back(Result);
return results;
}
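
The combination step spelled out in the comment above, an inclusive scan within each vector group plus an exclusive mux-level scan of the per-invocation reductions splatted back on top, can be modelled in plain C++ as follows. The add recurrence, the two-invocation data, and all names are invented for the example.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

using Vec = std::vector<int32_t>;

// Inclusive scan of one invocation's vector group (the vectorized scan call).
static Vec inclusiveScan(const Vec &v) {
  Vec out(v.size());
  std::partial_sum(v.begin(), v.end(), out.begin());
  return out;
}

int main() {
  // Two vectorized invocations in one mux sub-group.
  std::vector<Vec> inputs = {{1, 2, 3, 4}, {5, 6, 7, 8}};

  // Exclusive scan over the per-invocation *reductions* (the mux-level call).
  std::vector<int32_t> exclusive;
  int32_t running = 0;  // identity for the add recurrence
  for (const Vec &v : inputs) {
    exclusive.push_back(running);
    running += std::accumulate(v.begin(), v.end(), 0);
  }

  // Splat each invocation's exclusive result and add it to its vector scan.
  for (size_t i = 0; i < inputs.size(); ++i) {
    Vec scan = inclusiveScan(inputs[i]);
    for (int32_t &x : scan) x += exclusive[i];
    for (int32_t x : scan) std::cout << x << ' ';
  }
  std::cout << '\n';  // prints: 1 3 6 10 15 21 28 36
}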

@@ -2443,14 +2498,6 @@ Value *Packetizer::Impl::vectorizeCall(CallInst *CI) {
return nullptr;
}
if (Builtin.properties & compiler::utils::eBuiltinPropertyWorkItem) {
// The subgroup ID is just a simple index sequence. There is no dimension
// to it, and we only support 1D workgroups.
if (Builtin.ID == compiler::utils::eMuxBuiltinGetSubGroupLocalId) {
IRBuilder<> B(buildAfter(CI, F));
return multi_llvm::createIndexSequence(
B, VectorType::get(CI->getType(), SimdWidth), SimdWidth,
"subgroup.local.id");
}
return vectorizeWorkGroupCall(CI, Builtin);
}

@@ -2561,8 +2608,23 @@ Value *Packetizer::Impl::vectorizeWorkGroupCall(
// Do not vectorize ranks equal to vectorization dimension. The value of
// get_global_id with other ranks is uniform.

Value *IDToSplat = CI;
// Multiply the sub-group local ID by the vectorization factor, to vectorize
// across the entire sub-group size.
// For example, with a vector width of 4 and a mux sub-group size of 2, the
// apparent sub-group size is 8 and the sub-group IDs are:
// | mux sub group 0 | mux sub group 1 |
// |-----------------|-----------------|
// | 0 1 2 3 | 4 5 6 7 |
if (Builtin.ID == compiler::utils::eMuxBuiltinGetSubGroupLocalId) {
auto SimdWidthAsVal = B.getInt32(SimdWidth.getKnownMinValue());
IDToSplat = B.CreateMul(IDToSplat, !SimdWidth.isScalable()
                                       ? SimdWidthAsVal
                                       : B.CreateVScale(SimdWidthAsVal));
}

// Broadcast the builtin's return value.
Value *Splat = B.CreateVectorSplat(SimdWidth, CI);
Value *Splat = B.CreateVectorSplat(SimdWidth, IDToSplat);

// Add an index sequence [0, 1, 2, ...] to the value unless uniform.
auto const Uniformity = Builtin.uniformity;
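
The ID mapping described in the comment above can be sketched standalone: each mux invocation's apparent sub-group local IDs are its mux local ID scaled by the vectorization factor, plus the lane index sequence. The fixed widths and names below are illustrative, not taken from the patch.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  constexpr uint32_t VF = 4;               // vectorization factor
  constexpr uint32_t MuxSubGroupSize = 2;  // mux sub-group size

  for (uint32_t muxLocalId = 0; muxLocalId < MuxSubGroupSize; ++muxLocalId) {
    // The splat of (mux local ID * VF) plus the step vector [0, 1, 2, ...].
    std::vector<uint32_t> apparentIds(VF);
    for (uint32_t lane = 0; lane < VF; ++lane)
      apparentIds[lane] = muxLocalId * VF + lane;

    std::cout << "mux invocation " << muxLocalId << ':';
    for (uint32_t id : apparentIds) std::cout << ' ' << id;
    std::cout << '\n';  // 0 1 2 3, then 4 5 6 7
  }
}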
@@ -35,7 +35,8 @@ define spir_kernel void @get_sub_group_size(i32 addrspace(1)* %in, i32 addrspace
; CHECK-LABEL: define spir_kernel void @__vecz_nxv4_get_sub_group_size(
; CHECK: [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
; CHECK: [[W:%.*]] = shl i32 [[VSCALE]], 2
; CHECK: store i32 [[W]], ptr addrspace(1) {{.*}}
; CHECK: [[RED:%.*]] = call i32 @__mux_sub_group_reduce_add_i32(i32 [[W]])
; CHECK: store i32 [[RED]], ptr addrspace(1) {{.*}}
}

define spir_kernel void @get_sub_group_local_id(i32 addrspace(1)* %in, i32 addrspace(1)* %out) {
@@ -44,8 +45,17 @@ define spir_kernel void @get_sub_group_local_id(i32 addrspace(1)* %in, i32 addrs
store i32 %call, i32 addrspace(1)* %arrayidx, align 4
ret void
; CHECK-LABEL: define spir_kernel void @__vecz_nxv4_get_sub_group_local_id(
; CHECK: [[LID:%.*]] = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
; CHECK: store <vscale x 4 x i32> [[LID]], ptr addrspace(1) %out
; CHECK: %call = tail call spir_func i32 @__mux_get_sub_group_local_id()
; CHECK: [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
; CHECK: [[SHL:%.*]] = shl i32 %1, 2
; CHECK: [[MUL:%.*]] = mul i32 %call, [[SHL]]
; CHECK: [[SPLATINSERT:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[MUL]], i64 0
; CHECK: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[SPLATINSERT]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK: [[STEPVEC:%.*]] = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
; CHECK: [[LID:%.*]] = add <vscale x 4 x i32> [[SPLAT]], [[STEPVEC]]
; CHECK: [[EXT:%.*]] = sext i32 %call to i64
; CHECK: %arrayidx = getelementptr inbounds i32, ptr addrspace(1) %out, i64 [[EXT]]
; CHECK: store <vscale x 4 x i32> [[LID]], ptr addrspace(1) %arrayidx
}

define spir_kernel void @sub_group_broadcast(i32 addrspace(1)* %in, i32 addrspace(1)* %out) {
Expand All @@ -59,7 +69,8 @@ define spir_kernel void @sub_group_broadcast(i32 addrspace(1)* %in, i32 addrspac
; CHECK-LABEL: define spir_kernel void @__vecz_nxv4_sub_group_broadcast(
; CHECK: [[LD:%.*]] = load <vscale x 4 x i32>, ptr addrspace(1) {{%.*}}, align 4
; CHECK: [[EXT:%.*]] = extractelement <vscale x 4 x i32> [[LD]], {{(i32|i64)}} 0
; CHECK: [[INS:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[EXT]], {{(i32|i64)}} 0
; CHECK: [[BDCAST:%.*]] = call spir_func i32 @__mux_sub_group_broadcast_i32(i32 [[EXT]], i32 0)
; CHECK: [[INS:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[BDCAST]], {{(i32|i64)}} 0
; CHECK: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[INS]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK: store <vscale x 4 x i32> [[SPLAT]], ptr addrspace(1)
}