Skip to content

Commit ea6ada0

Browse files
authored
Merge pull request #184 from frasercrmck/fix-vecz-sub-group-broadcast
[vecz] Fix vectorization of sub-group broadcasts
2 parents 7bea3c5 + c2fe3d1 commit ea6ada0

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

modules/compiler/vecz/source/transform/packetizer.cpp

+27-6
Original file line number | Diff line number | Diff line change
@@ -1249,17 +1249,17 @@ Value *Packetizer::Impl::packetizeGroupBroadcast(Instruction *I) {
12491249
if (SimdWidth.isScalable()) {
12501250
idxFactor = B.CreateVScale(minVal);
12511251
}
1252-
idx = B.CreateURem(idx, idxFactor);
1252+
auto *const vecIdx = B.CreateURem(idx, idxFactor);
12531253

12541254
Value *val = nullptr;
12551255
// Optimize the constant fixed-vector case, where we can choose the exact
12561256
// subpacket to extract from directly.
1257-
if (isa<ConstantInt>(idx) && !SimdWidth.isScalable()) {
1257+
if (isa<ConstantInt>(vecIdx) && !SimdWidth.isScalable()) {
12581258
ValuePacket opPackets;
12591259
op.getPacketValues(opPackets);
12601260
auto factor = SimdWidth.divideCoefficientBy(opPackets.size());
12611261
const unsigned subvecSize = factor.getFixedValue();
1262-
const unsigned idxVal = cast<ConstantInt>(idx)->getZExtValue();
1262+
const unsigned idxVal = cast<ConstantInt>(vecIdx)->getZExtValue();
12631263
// If individual elements are scalar (through instantiation, say) then just
12641264
// use the desired packet directly.
12651265
if (subvecSize == 1) {
@@ -1268,16 +1268,37 @@ Value *Packetizer::Impl::packetizeGroupBroadcast(Instruction *I) {
12681268
// Else extract from the correct packet, adjusting the index as we go.
12691269
val = B.CreateExtractElement(
12701270
opPackets[idxVal / subvecSize],
1271-
ConstantInt::get(idx->getType(), idxVal % subvecSize));
1271+
ConstantInt::get(vecIdx->getType(), idxVal % subvecSize));
12721272
}
12731273
} else {
1274-
val = B.CreateExtractElement(op.getAsValue(), idx);
1274+
val = B.CreateExtractElement(op.getAsValue(), vecIdx);
12751275
}
12761276

1277-
// We leave the origial broadcast function and divert the vectorized
1277+
// We leave the original broadcast function and divert the vectorized
12781278
// broadcast through it, giving us a broadcast over the full apparent
12791279
// sub-group or work-group size (vecz * mux).
12801280
CI->setOperand(argIdx, val);
1281+
if (!isWorkGroup) {
1282+
// For sub-groups, we need to normalize the sub-group ID into the range of
1283+
// mux sub-groups.
1284+
// |-----------------|-----------------|
1285+
// | broadcast(X, 6) | broadcast(A, 6) |
1286+
// VF=4 |-----------------|-----------------|
1287+
// | b(<X,Y,Z,W>, 6) | b(<A,B,C,D>, 6) |
1288+
// |-----------------|-----------------|
1289+
// M=I/4 | 1 | 1 |
1290+
// V=I%4 | 2 | 2 |
1291+
// |-----------------|-----------------|
1292+
// | <X,Y,Z,W>[V] | <A,B,C,D>[V] |
1293+
// | Z | C |
1294+
// |-----------------|-----------------|
1295+
// | broadcast(Z, M) | broadcast(C, M) |
1296+
// res | C | C |
1297+
// splat | <C,C,C,C> | <C,C,C,C> |
1298+
// |-----------------|-----------------|
1299+
auto *const muxIdx = B.CreateUDiv(idx, idxFactor);
1300+
CI->setOperand(argIdx + 1, muxIdx);
1301+
}
12811302

12821303
return CI;
12831304
}

modules/compiler/vecz/test/lit/llvm/subgroup_builtins.ll

+1-1
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ define spir_kernel void @sub_group_broadcast_wider_than_vf(i32 addrspace(1)* %in
8383
; CHECK: [[LD:%.*]] = load <4 x i32>, ptr addrspace(1) {{%.*}}, align 4
8484
; The sixth sub-group member is the (6 % 4 ==) 2nd vector group member
8585
; CHECK: [[EXT:%.*]] = extractelement <4 x i32> [[LD]], i64 2
86-
; CHECK: [[BDCAST:%.*]] = call spir_func i32 @__mux_sub_group_broadcast_i32(i32 [[EXT]], i32 6)
86+
; CHECK: [[BDCAST:%.*]] = call spir_func i32 @__mux_sub_group_broadcast_i32(i32 [[EXT]], i32 1)
8787
; CHECK: [[HEAD:%.*]] = insertelement <4 x i32> poison, i32 [[BDCAST]], i64 0
8888
; CHECK: [[SPLAT:%.*]] = shufflevector <4 x i32> [[HEAD]], <4 x i32> {{(undef|poison)}}, <4 x i32> zeroinitializer
8989
; CHECK: store <4 x i32> [[SPLAT]], ptr addrspace(1)

0 commit comments

Comments
 (0)