Skip to content

Commit

Permalink
Merge pull request #183 from frasercrmck/vecz-subgroup-shuffle-up-down
Browse files Browse the repository at this point in the history
[vecz] Packetize sub-group shuffle_(up|down) builtins
  • Loading branch information
frasercrmck authored Nov 6, 2023
2 parents ea6ada0 + 3b3882b commit e2f6294
Show file tree
Hide file tree
Showing 3 changed files with 686 additions and 36 deletions.
320 changes: 297 additions & 23 deletions modules/compiler/vecz/source/transform/packetizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ class Packetizer::Impl : public Packetizer {
/// @return Packetized values.
ValuePacket packetizeAndGet(Value *V, unsigned Width);

/// @brief Helper to produce a Result from a Packet
///
/// Registers the packetized value(s) against the original instruction and
/// schedules the original instruction for deletion when it is replaced by a
/// different value.
///
/// @param[in] I Instruction that has been packetized.
/// @param[in] Packet Packetized values standing in for I. If empty, an empty
/// Result is returned and I is left untouched.
/// @param[in] UpdateStats If true, the packetization statistic is incremented
/// when a packet of more than one value is created.
///
/// @return Result wrapping the packetized value(s).
Packetizer::Result getPacketizationResult(
Instruction *I, const SmallVectorImpl<Value *> &Packet,
bool UpdateStats = false);

/// @brief Packetize the given value from the function, only if it is a
/// varying value. Ensures Mask Varying values are handled correctly.
///
Expand Down Expand Up @@ -233,6 +238,17 @@ class Packetizer::Impl : public Packetizer {
/// @return Packetized instructions.
Result packetizeSubgroupShuffleXor(
Instruction *Ins, compiler::utils::GroupCollective ShuffleXor);
/// @brief Packetize a sub-group shuffle-up or shuffle-down builtin
///
/// Note - not any shuffle-like operation, but specifically the 'shuffle_up'
/// and 'shuffle_down' builtins.
///
/// @param[in] Ins Instruction to packetize. Must be a call instruction.
/// @param[in] ShuffleUpDown Shuffle to packetize.
///
/// @return Packetized instructions, or an empty Result if the shuffle could
/// not be packetized (e.g., when vectorizing by a scalable factor, or when one
/// of its operands could not be packetized).
Result packetizeSubgroupShuffleUpDown(
Instruction *Ins, compiler::utils::GroupCollective ShuffleUpDown);

/// @brief Packetize PHI node.
///
Expand Down Expand Up @@ -803,6 +819,41 @@ PacketRange Packetizer::createPacket(Value *V, unsigned width) {
return Result(*this, V, &info).createPacket(width);
}

/// @brief Wraps a list of packetized values up as the Result for the
/// instruction they replace.
///
/// @param[in] I Instruction that has been packetized.
/// @param[in] Packet Values standing in for I.
/// @param[in] UpdateStats Whether to count a newly created multi-value packet
/// in the packetization statistic.
///
/// @return An empty Result if Packet is empty; otherwise a Result tying I to
/// its packetized value(s).
Packetizer::Result Packetizer::Impl::getPacketizationResult(
    Instruction *I, const SmallVectorImpl<Value *> &Packet, bool UpdateStats) {
  auto NumValues = Packet.size();

  // Nothing was produced: hand back an empty result.
  if (NumValues == 0) {
    return Result(*this);
  }

  if (NumValues > 1) {
    // Several values: build a fresh packet for the instruction and copy the
    // values across one by one.
    IC.deleteInstructionLater(I);
    auto &PacketInfo = packets[I];
    auto Wrapped = Result(*this, I, &PacketInfo);
    auto Slots = Wrapped.createPacket(NumValues);
    unsigned Slot = 0;
    for (Value *const Val : Packet) {
      Slots[Slot] = Val;
      ++Slot;
    }

    if (UpdateStats) {
      ++VeczPacketized;
    }
    PacketInfo.numInstances = NumValues;
    return Wrapped;
  }

  // Exactly one value: it can directly take the place of the old instruction.
  Value *Single = Packet.front();
  // The scalar is only scheduled for deletion when something else replaces it.
  if (Single != I) {
    IC.deleteInstructionLater(I);
  }
  vectorizeDI(I, Single);
  return assign(I, Single);
}

Value *Packetizer::Impl::reduceBranchCond(Value *cond, Instruction *terminator,
bool allOf) {
// Get the branch condition at its natural packet width
Expand Down Expand Up @@ -938,6 +989,12 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
return s;
}
break;
case compiler::utils::GroupCollective::OpKind::ShuffleUp:
case compiler::utils::GroupCollective::OpKind::ShuffleDown:
if (auto s = packetizeSubgroupShuffleUpDown(Ins, *shuffle)) {
return s;
}
break;
}
// We can't packetize all sub-group shuffle-like operations, but we also
// can't vectorize or instantiate them - so provide a diagnostic saying as
Expand Down Expand Up @@ -1099,29 +1156,8 @@ Packetizer::Result Packetizer::Impl::packetizeInstruction(Instruction *Ins) {
break;
}

if (!results.empty()) {
auto packetWidth = results.size();
if (packetWidth == 1) {
Value *vec = results.front();
if (vec != Ins) {
// Only delete if the vectorized value is different from the scalar.
IC.deleteInstructionLater(Ins);
}
vectorizeDI(Ins, vec);
return assign(Ins, vec);
} else {
IC.deleteInstructionLater(Ins);
auto &info = packets[Ins];
auto res = Result(*this, Ins, &info);
auto P = res.createPacket(packetWidth);
for (unsigned i = 0; i < packetWidth; ++i) {
P[i] = results[i];
// TODO CA-3376: vectorize the debug instructions
}
info.numInstances = packetWidth;
++VeczPacketized;
return res;
}
if (auto res = getPacketizationResult(Ins, results, /*update stats*/ true)) {
return res;
}

if (auto *vec = vectorizeInstruction(Ins)) {
Expand Down Expand Up @@ -1587,6 +1623,244 @@ Packetizer::Result Packetizer::Impl::packetizeSubgroupShuffleXor(
return assign(CI, CombinedShuffle);
}

/// @brief Packetizes a sub-group 'shuffle_up' or 'shuffle_down' builtin call.
///
/// Each vectorized lane is served by one call to the corresponding mux
/// shuffle builtin operating on the whole packetized operands, followed by an
/// extraction of the element (or sub-vector) that the lane is after.
///
/// @param[in] I Call instruction to packetize.
/// @param[in] ShuffleUpDown Description of the shuffle; its operation must be
/// either ShuffleUp or ShuffleDown.
///
/// @return The packetized values, or an empty Result if the vectorization
/// factor is scalable, if any operand could not be packetized, or if the
/// sub-group local ID could not be vectorized.
Packetizer::Result Packetizer::Impl::packetizeSubgroupShuffleUpDown(
    Instruction *I, compiler::utils::GroupCollective ShuffleUpDown) {
  bool const IsDown =
      ShuffleUpDown.Op == compiler::utils::GroupCollective::OpKind::ShuffleDown;
  assert((IsDown || ShuffleUpDown.Op ==
                        compiler::utils::GroupCollective::OpKind::ShuffleUp) &&
         "Invalid shuffle kind");

  auto *const CI = cast<CallInst>(I);

  // We don't support scalable vectorization of sub-group shuffles.
  if (SimdWidth.isScalable()) {
    return Packetizer::Result(*this);
  }
  unsigned const VF = SimdWidth.getFixedValue();

  // LHS is 'current' for a down-shuffle, and 'previous' for an up-shuffle.
  auto *const LHSOp = CI->getArgOperand(0);
  // RHS is 'next' for a down-shuffle, and 'current' for an up-shuffle.
  auto *const RHSOp = CI->getArgOperand(1);
  auto *const DeltaOp = CI->getArgOperand(2);

  auto PackDelta = packetize(DeltaOp);
  if (!PackDelta) {
    return Packetizer::Result(*this);
  }

  auto PackLHS = packetize(LHSOp);
  if (!PackLHS) {
    return Packetizer::Result(*this);
  }

  auto PackRHS = packetize(RHSOp);
  if (!PackRHS) {
    return Packetizer::Result(*this);
  }

  auto *const LHSPackVal = PackLHS.getAsValue();
  auto *const RHSPackVal = PackRHS.getAsValue();
  assert(LHSPackVal && RHSPackVal &&
         LHSPackVal->getType() == RHSPackVal->getType() &&
         "Mismatched packetized sub-group shuffle up/down operands");

  // Remember in the example below that the builtins take *deltas* which add
  // onto the mux sub-group local ID. Therefore a delta of 2 returns different
  // data for each of the mux sub-group elements.
  //                |----------------------------|----------------------------|
  //                | shuffle_down(A, X, 2)      | shuffle_down(E, I, 2)      |
  // VF=4           |----------------------------|----------------------------|
  //                | s(<A,B,C,D>, <X,Y,Z,W>, 2) | s(<E,F,G,H>, <I,J,K,L>, 2) |
  // SGIds          | 0,1,2,3                    | 4,5,6,7                    |
  // SGIds+D        | 2,3,4,5                    | 6,7,8,9                    |
  // MuxSGIds       | 0,0,0,0                    | 1,1,1,1                    |
  //                |----------------------------|----------------------------|
  // M=(SGIds+D)/VF | 0,0,1,1                    | 1,1,2,2                    |
  // V=(SGIds+D)%VF | 2,3,0,1                    | 2,3,0,1                    |
  //                |----------------------------|----------------------------|
  // M - MuxSGIds   | 0,0,1,1                    | 0,0,1,1                    |
  //                |----------------------------|----------------------------|
  // Shuff[0]       | s(<A,B,C,D>, <X,Y,Z,W>, 0) | s(<E,F,G,H>, <I,J,K,L>, 0) |
  // Data returned  | 0+0 => 0 => <A,B,C,D>      | 1+0 => 1 => <E,F,G,H>      |
  // Shuff[0][V[0]] | <A,B,C,D>[2] = C           | <E,F,G,H>[2] = G           |
  //                |----------------------------|----------------------------|
  // Shuff[1]       | s(<A,B,C,D>, <X,Y,Z,W>, 0) | s(<E,F,G,H>, <I,J,K,L>, 0) |
  // Data returned  | 0+0 => 0 => <A,B,C,D>      | 1+0 => 1 => <E,F,G,H>      |
  // Shuff[1][V[1]] | <A,B,C,D>[3] = D           | <E,F,G,H>[3] = H           |
  //                |----------------------------|----------------------------|
  // Shuff[2]       | s(<A,B,C,D>, <X,Y,Z,W>, 1) | s(<E,F,G,H>, <I,J,K,L>, 1) |
  // Data returned  | 0+1 => 1 => <E,F,G,H>      | 1+1 => 2 => 0 => <X,Y,Z,W> |
  // Shuff[2][V[2]] | <E,F,G,H>[0] = E           | <X,Y,Z,W>[0] = X           |
  //                |----------------------------|----------------------------|
  // Shuff[3]       | s(<A,B,C,D>, <X,Y,Z,W>, 1) | s(<E,F,G,H>, <I,J,K,L>, 1) |
  // Data returned  | 0+1 => 1 => <E,F,G,H>      | 1+1 => 2 => 0 => <X,Y,Z,W> |
  // Shuff[3][V[3]] | <E,F,G,H>[1] = F           | <X,Y,Z,W>[1] = Y           |
  //                |----------------------------|----------------------------|
  // Result         | C,D,E,F                    | G,H,X,Y                    |
  IRBuilder<> B(CI);

  // Grab the packetized/vectorized sub-group local IDs
  auto *const SubgroupLocalIDFn = Ctx.builtins().getOrDeclareMuxBuiltin(
      compiler::utils::eMuxBuiltinGetSubGroupLocalId, *F.getParent(),
      {CI->getType()});
  assert(SubgroupLocalIDFn);

  auto *const SubgroupLocalID =
      B.CreateCall(SubgroupLocalIDFn, {}, "sg.local.id");
  auto const Builtin =
      Ctx.builtins().analyzeBuiltinCall(*SubgroupLocalID, Dimension);

  // Vectorize the sub-group local ID
  auto *const VecSubgroupLocalID =
      vectorizeWorkGroupCall(SubgroupLocalID, Builtin);
  if (!VecSubgroupLocalID) {
    return Packetizer::Result(*this);
  }
  VecSubgroupLocalID->setName("vec.sg.local.id");

  auto *const DeltaVal = PackDelta.getAsValue();

  // The delta is always i32, as is the sub-group local ID. Vectorizing both of
  // them should result in the same vector type, with as many elements as the
  // vectorization factor.
  assert(DeltaVal && DeltaVal->getType() == VecSubgroupLocalID->getType() &&
         DeltaVal->getType()->isVectorTy() &&
         cast<VectorType>(DeltaVal->getType())
                 ->getElementCount()
                 .getKnownMinValue() == VF &&
         "Unexpected vectorization of sub-group shuffle up/down");

  // Produce the sum of the sub-group IDs with the 'delta', as per the
  // semantics of the builtin.
  auto *const IDPlusDelta = IsDown ? B.CreateAdd(VecSubgroupLocalID, DeltaVal)
                                   : B.CreateSub(VecSubgroupLocalID, DeltaVal);

  // We need to sanitize the input indices so that they stay within the range
  // of one vectorized group.
  auto *const VecIdxFactor = ConstantInt::get(SubgroupLocalID->getType(), VF);
  // The same splat of the vectorization factor is the divisor of every
  // division/modulo below, so build it once up front.
  auto *const VecIdxFactorSplat = B.CreateVectorSplat(VF, VecIdxFactor);

  // Bring this ID into the range of 'mux' sub-groups by dividing it by the
  // vector size. We have to do this differently for 'up' and 'down' shuffles
  // because the 'up' shuffles use signed indexing, and we need to round down
  // to negative infinity to get the right sub-group delta.
  Value *MuxAbsoluteIDs = nullptr;
  Value *VecEltIDs = nullptr;
  if (IsDown) {
    MuxAbsoluteIDs = B.CreateUDiv(IDPlusDelta, VecIdxFactorSplat);
    // And into the range of the vector group
    VecEltIDs = B.CreateURem(IDPlusDelta, VecIdxFactorSplat);
  } else {
    // Note that shuffling up is more complicated, owing to the signed
    // sub-group local IDs.
    // The steps are identical to the example outlined above, except both the
    // division and modulo operations performed on the sub-group IDs have to
    // floor towards negative infinity. That is, we want to see:
    //                |----------------------------|---------------------------|
    //                | shuffle_up(A, X, 2)        | shuffle_up(E, I, 2)       |
    // VF=4           |----------------------------|---------------------------|
    //                | s(<A,B,C,D>, <X,Y,Z,W>, 2) | s(<E,F,G,H>, <I,J,K,L>, 2)|
    // SGIds          | 0,1,2,3                    | 4,5,6,7                   |
    // SGIds-D        | -2,-1,0,1                  | 2,3,4,5                   |
    // MuxSGIds       | 0,0,0,0                    | 1,1,1,1                   |
    //                |----------------------------|---------------------------|
    // both flooring: |                            |                           |
    // M=(SGIds-D)/VF | -1,-1,0,0                  | 0,0,1,1                   |
    // V=(SGIds-D)%VF | 2,3,0,1                    | 2,3,0,1                   |
    //                |----------------------------|---------------------------|
    // MuxSGIds - M   | 1,1,0,0                    | 1,1,0,0                   |
    //                |----------------------------|---------------------------|
    //
    // We use the following formulae for division and modulo:
    //   int div_floor(int x, int y) {
    //     int q = x/y;
    //     int r = x%y;
    //     if ((r!=0) && ((r<0) != (y<0))) --q;
    //     return q;
    //   }
    //   int mod_floor(int x, int y) {
    //     int r = x%y;
    //     if ((r!=0) && ((r<0) != (y<0))) { r += y; }
    //     return r;
    //   }
    // We note also that the conditions are equal between the two operations,
    // and that the condition is equivalent to:
    //   if ((r!=0) && ((x ^ y) < 0)) { ... }
    // (see https://alive2.llvm.org/ce/z/ebGrdL)
    auto *const X = IDPlusDelta;
    auto *const Y = VecIdxFactorSplat;
    auto *const Quotient = B.CreateSDiv(X, Y, "quotient");
    auto *const Remainder = B.CreateSRem(X, Y, "remainder");

    auto *const ArgXor = B.CreateXor(X, Y, "arg.xor");
    auto *const One = ConstantInt::get(ArgXor->getType(), 1);
    auto *const Zero = ConstantInt::get(ArgXor->getType(), 0);
    auto *const ArgSignDifferent =
        B.CreateICmpSLT(ArgXor, Zero, "signs.different");
    auto *const RemainderIsNotZero =
        B.CreateICmpNE(Remainder, Zero, "remainder.nonzero");
    auto *const ConditionHolds =
        B.CreateAnd(RemainderIsNotZero, ArgSignDifferent, "condition.holds");
    auto *const QuotientMinus1 = B.CreateSub(Quotient, One, "quotient.minus.1");
    auto *const RemainderPlusY = B.CreateAdd(Remainder, Y, "remainder.plus.y");

    MuxAbsoluteIDs = B.CreateSelect(ConditionHolds, QuotientMinus1, Quotient);
    VecEltIDs = B.CreateSelect(ConditionHolds, RemainderPlusY, Remainder);
  }

  // We've produced the 'absolute' mux sub-group local IDs for the data we want
  // to access in each shuffle, but we want to get back to 'relative' IDs in
  // the form of deltas. Splat the mux sub-group local ID.
  auto *const SplatSubgroupLocalID =
      B.CreateVectorSplat(VF, SubgroupLocalID, "splat.sg.local.id");
  auto *DeltaLHS = MuxAbsoluteIDs;
  auto *DeltaRHS = SplatSubgroupLocalID;
  if (!IsDown) {
    // For 'up' shuffles, we invert the operation as the deltas are implicitly
    // negative. See above.
    std::swap(DeltaLHS, DeltaRHS);
  }
  auto *const MuxDeltas =
      B.CreateSub(DeltaLHS, DeltaRHS, "mux.sg.local.id.deltas");

  auto ShuffleID = Ctx.builtins().getMuxGroupCollective(ShuffleUpDown);
  auto *const ShuffleFn = Ctx.builtins().getOrDeclareMuxBuiltin(
      ShuffleID, *F.getParent(), {LHSPackVal->getType()});
  assert(ShuffleFn);

  // Whether the shuffled data is itself a vector is the same for every lane,
  // so work it out once rather than on every loop iteration.
  auto *const DataVecTy = dyn_cast<VectorType>(LHSOp->getType());

  SmallVector<Value *, 16> Results(VF);
  for (unsigned i = 0; i != VF; i++) {
    // One mux-level shuffle per lane, using that lane's own mux delta.
    auto *const MuxDelta = B.CreateExtractElement(MuxDeltas, B.getInt32(i));
    auto *const Shuffle =
        B.CreateCall(ShuffleFn, {LHSPackVal, RHSPackVal, MuxDelta});

    Value *Elt = nullptr;
    auto *const Idx = B.CreateExtractElement(VecEltIDs, B.getInt32(i));
    if (!DataVecTy) {
      Elt = B.CreateExtractElement(Shuffle, Idx);
    } else {
      // For vector data types we need to extract consecutive elements starting
      // at the sub-vector whose index is Idx.
      Elt = UndefValue::get(DataVecTy);
      auto VecWidth = DataVecTy->getElementCount().getFixedValue();
      // Idx is the 'base' of the subvector, whose elements are stored
      // sequentially from that point.
      auto *const VecVecGroupIdx = B.CreateMul(Idx, B.getInt32(VecWidth));
      for (unsigned j = 0; j != VecWidth; j++) {
        auto *const E = B.CreateExtractElement(
            Shuffle, B.CreateAdd(VecVecGroupIdx, B.getInt32(j)));
        Elt = B.CreateInsertElement(Elt, E, B.getInt32(j));
      }
    }
    Results[i] = Elt;
  }

  // No explicit deletion of the original call is needed here: every entry of
  // Results is a newly created value, so getPacketizationResult schedules the
  // call for deletion itself.
  return getPacketizationResult(I, Results);
}

Value *Packetizer::Impl::packetizeMaskVarying(Instruction *I) {
if (auto memop = MemOp::get(I)) {
auto *const mask = memop->getMaskOperand();
Expand Down
Loading

0 comments on commit e2f6294

Please sign in to comment.