@@ -1249,17 +1249,17 @@ Value *Packetizer::Impl::packetizeGroupBroadcast(Instruction *I) {
1249
1249
if (SimdWidth.isScalable ()) {
1250
1250
idxFactor = B.CreateVScale (minVal);
1251
1251
}
1252
- idx = B.CreateURem (idx, idxFactor);
1252
+ auto * const vecIdx = B.CreateURem (idx, idxFactor);
1253
1253
1254
1254
Value *val = nullptr ;
1255
1255
// Optimize the constant fixed-vector case, where we can choose the exact
1256
1256
// subpacket to extract from directly.
1257
- if (isa<ConstantInt>(idx ) && !SimdWidth.isScalable ()) {
1257
+ if (isa<ConstantInt>(vecIdx ) && !SimdWidth.isScalable ()) {
1258
1258
ValuePacket opPackets;
1259
1259
op.getPacketValues (opPackets);
1260
1260
auto factor = SimdWidth.divideCoefficientBy (opPackets.size ());
1261
1261
const unsigned subvecSize = factor.getFixedValue ();
1262
- const unsigned idxVal = cast<ConstantInt>(idx )->getZExtValue ();
1262
+ const unsigned idxVal = cast<ConstantInt>(vecIdx )->getZExtValue ();
1263
1263
// If individual elements are scalar (through instantiation, say) then just
1264
1264
// use the desired packet directly.
1265
1265
if (subvecSize == 1 ) {
@@ -1268,16 +1268,37 @@ Value *Packetizer::Impl::packetizeGroupBroadcast(Instruction *I) {
1268
1268
// Else extract from the correct packet, adjusting the index as we go.
1269
1269
val = B.CreateExtractElement (
1270
1270
opPackets[idxVal / subvecSize],
1271
- ConstantInt::get (idx ->getType (), idxVal % subvecSize));
1271
+ ConstantInt::get (vecIdx ->getType (), idxVal % subvecSize));
1272
1272
}
1273
1273
} else {
1274
- val = B.CreateExtractElement (op.getAsValue (), idx );
1274
+ val = B.CreateExtractElement (op.getAsValue (), vecIdx );
1275
1275
}
1276
1276
1277
- // We leave the origial broadcast function and divert the vectorized
1277
+ // We leave the original broadcast function and divert the vectorized
1278
1278
// broadcast through it, giving us a broadcast over the full apparent
1279
1279
// sub-group or work-group size (vecz * mux).
1280
1280
CI->setOperand (argIdx, val);
1281
+ if (!isWorkGroup) {
1282
+ // For sub-groups, we need to normalize the sub-group ID into the range of
1283
+ // mux sub-groups.
1284
+ // |-----------------|-----------------|
1285
+ // | broadcast(X, 6) | broadcast(A, 6) |
1286
+ // VF=4 |-----------------|-----------------|
1287
+ // | b(<X,Y,Z,W>, 6) | b(<A,B,C,D>, 6) |
1288
+ // |-----------------|-----------------|
1289
+ // M=I/4 | 1 | 1 |
1290
+ // V=I%4 | 2 | 2 |
1291
+ // |-----------------|-----------------|
1292
+ // | <X,Y,Z,W>[V] | <A,B,C,D>[V] |
1293
+ // | Z | C |
1294
+ // |-----------------|-----------------|
1295
+ // | broadcast(Z, M) | broadcast(C, M) |
1296
+ // res | C | C |
1297
+ // splat | <C,C,C,C> | <C,C,C,C> |
1298
+ // |-----------------|-----------------|
1299
+ auto *const muxIdx = B.CreateUDiv (idx, idxFactor);
1300
+ CI->setOperand (argIdx + 1 , muxIdx);
1301
+ }
1281
1302
1282
1303
return CI;
1283
1304
}
0 commit comments