From 3b3882b6bdded76eafd8ffef50fb75f29eeb23eb Mon Sep 17 00:00:00 2001 From: Fraser Cormack Date: Mon, 30 Oct 2023 16:58:27 +0000 Subject: [PATCH] [vecz] Packetize sub-group shuffle_(up|down) builtins This extends fixed-width vectorization capabilities to `__mux_sub_group_shuffle_up` and `__mux_sub_group_shuffle_down` builtins. Again, these aren't very efficiently vectorized as we have to perform the shuffle for each work-item in a pseudo-scalarized fashion. --- .../vecz/source/transform/packetizer.cpp | 320 ++++++++++++++++-- .../test/lit/llvm/subgroup_shuffle_down.ll | 186 +++++++++- .../vecz/test/lit/llvm/subgroup_shuffle_up.ll | 216 +++++++++++- 3 files changed, 686 insertions(+), 36 deletions(-) diff --git a/modules/compiler/vecz/source/transform/packetizer.cpp b/modules/compiler/vecz/source/transform/packetizer.cpp index de7654475..071d72773 100644 --- a/modules/compiler/vecz/source/transform/packetizer.cpp +++ b/modules/compiler/vecz/source/transform/packetizer.cpp @@ -152,6 +152,11 @@ class Packetizer::Impl : public Packetizer { /// @return Packetized values. ValuePacket packetizeAndGet(Value *V, unsigned Width); + /// @brief Helper to produce a Result from a Packet + Packetizer::Result getPacketizationResult( + Instruction *I, const SmallVectorImpl &Packet, + bool UpdateStats = false); + /// @brief Packetize the given value from the function, only if it is a /// varying value. Ensures Mask Varying values are handled correctly. /// @@ -233,6 +238,17 @@ class Packetizer::Impl : public Packetizer { /// @return Packetized instructions. Result packetizeSubgroupShuffleXor( Instruction *Ins, compiler::utils::GroupCollective ShuffleXor); + /// @brief Packetize a sub-group shuffle-up or shuffle-down builtin + /// + /// Note - not any shuffle-like operation, but specifically the 'shuffle_up' + /// and 'shuffle_down' builtins. + /// + /// @param[in] Ins Instruction to packetize. + /// @param[in] ShuffleUpDown Shuffle to packetize. + /// + /// @return Packetized instructions. + Result packetizeSubgroupShuffleUpDown( + Instruction *Ins, compiler::utils::GroupCollective ShuffleUpDown); /// @brief Packetize PHI node. /// @@ -803,6 +819,41 @@ PacketRange Packetizer::createPacket(Value *V, unsigned width) { return Result(*this, V, &info).createPacket(width); } +Packetizer::Result Packetizer::Impl::getPacketizationResult( + Instruction *I, const SmallVectorImpl &Packet, bool UpdateStats) { + if (Packet.empty()) { + return Result(*this); + } + auto PacketWidth = Packet.size(); + + // If there's only one value in the packet, we can assign the new packetized + // value to the old instruction directly. + if (PacketWidth == 1) { + Value *Vec = Packet.front(); + if (Vec != I) { + // Only delete if the vectorized value is different from the scalar. + IC.deleteInstructionLater(I); + } + vectorizeDI(I, Vec); + return assign(I, Vec); + } + + // Otherwise we have to create a 'Result' out of the packetized values. + IC.deleteInstructionLater(I); + auto &Info = packets[I]; + auto Res = Result(*this, I, &Info); + auto P = Res.createPacket(PacketWidth); + for (unsigned i = 0; i < PacketWidth; ++i) { + P[i] = Packet[i]; + } + + if (UpdateStats) { + ++VeczPacketized; + } + Info.numInstances = PacketWidth; + return Res; +} + Value *Packetizer::Impl::reduceBranchCond(Value *cond, Instruction *terminator, bool allOf) { // Get the branch condition at its natural packet width @@ -938,6 +989,12 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) { return s; } break; + case compiler::utils::GroupCollective::OpKind::ShuffleUp: + case compiler::utils::GroupCollective::OpKind::ShuffleDown: + if (auto s = packetizeSubgroupShuffleUpDown(Ins, *shuffle)) { + return s; + } + break; } // We can't packetize all sub-group shuffle-like operations, but we also // can't vectorize or instantiate them - so provide a diagnostic saying as @@ -1099,29 +1156,8 @@ Packetizer::Result Packetizer::Impl::packetizeInstruction(Instruction *Ins) { break; } - if (!results.empty()) { - auto packetWidth = results.size(); - if (packetWidth == 1) { - Value *vec = results.front(); - if (vec != Ins) { - // Only delete if the vectorized value is different from the scalar. - IC.deleteInstructionLater(Ins); - } - vectorizeDI(Ins, vec); - return assign(Ins, vec); - } else { - IC.deleteInstructionLater(Ins); - auto &info = packets[Ins]; - auto res = Result(*this, Ins, &info); - auto P = res.createPacket(packetWidth); - for (unsigned i = 0; i < packetWidth; ++i) { - P[i] = results[i]; - // TODO CA-3376: vectorize the debug instructions - } - info.numInstances = packetWidth; - ++VeczPacketized; - return res; - } + if (auto res = getPacketizationResult(Ins, results, /*update stats*/ true)) { + return res; } if (auto *vec = vectorizeInstruction(Ins)) { @@ -1587,6 +1623,244 @@ Packetizer::Result Packetizer::Impl::packetizeSubgroupShuffleXor( return assign(CI, CombinedShuffle); } +Packetizer::Result Packetizer::Impl::packetizeSubgroupShuffleUpDown( + Instruction *I, compiler::utils::GroupCollective ShuffleUpDown) { + bool IsDown = + ShuffleUpDown.Op == compiler::utils::GroupCollective::OpKind::ShuffleDown; + assert((IsDown || ShuffleUpDown.Op == + compiler::utils::GroupCollective::OpKind::ShuffleUp) && + "Invalid shuffle kind"); + + auto *const CI = cast(I); + + // We don't support scalable vectorization of sub-group shuffles. + if (SimdWidth.isScalable()) { + return Packetizer::Result(*this); + } + unsigned const VF = SimdWidth.getFixedValue(); + + // LHS is 'current' for a down-shuffle, and 'previous' for an up-shuffle. + auto *const LHSOp = CI->getArgOperand(0); + // RHS is 'next' for a down-shuffle, and 'current' for an up-shuffle. + auto *const RHSOp = CI->getArgOperand(1); + auto *const DeltaOp = CI->getArgOperand(2); + + auto PackDelta = packetize(DeltaOp); + if (!PackDelta) { + return Packetizer::Result(*this); + } + + auto PackLHS = packetize(LHSOp); + if (!PackLHS) { + return Packetizer::Result(*this); + } + + auto PackRHS = packetize(RHSOp); + if (!PackRHS) { + return Packetizer::Result(*this); + } + + auto *const LHSPackVal = PackLHS.getAsValue(); + auto *const RHSPackVal = PackRHS.getAsValue(); + assert(LHSPackVal && RHSPackVal && + LHSPackVal->getType() == RHSPackVal->getType()); + + // Remember in the example below that the builtins take *deltas* which add + // onto the mux sub-group local ID. Therefore a delta of 2 returns different + // data for each of the mux sub-group elements. + // |----------------------------|----------------------------| + // | shuffle_down(A, X, 2) | shuffle_down(E, I, 2) | + // VF=4 |----------------------------|----------------------------| + // | s(, , 2) | s(, , 2) | + // SGIds | 0,1,2,3 | 4,5,6,7 | + // SGIds+D | 2,3,4,5 | 6,7,8,9 | + // MuxSGIds | 0,0,0,0 | 1,1,1,1 | + // |----------------------------|----------------------------| + // M=(SGIds+D)/VF | 0,0,1,1 | 1,1,2,2 | + // V=(SGIds+D)%VF | 2,3,0,1 | 2,3,0,1 | + // |----------------------------|----------------------------| + // M - MuxSGIds | 0,0,1,1 | 0,0,1,1 | + // |----------------------------|----------------------------| + // Shuff[0] | s(, , 0) | s(, , 0) | + // Data returned | 0+0 => 0 => | 1+0 => 1 => | + // Shuff[0][V[0]] | [2] = C | [2] = G | + // |----------------------------|----------------------------| + // Shuff[1] | s(, , 0) | s(, , 0) | + // Data returned | 0+0 => 0 => | 1+0 => 1 => | + // Shuff[1][V[1]] | [3] = D | [3] = H | + // |----------------------------|----------------------------| + // Shuff[2] | s(, , 1) | s(, , 1) | + // Data returned | 0+1 => 1 => | 1+1 => 2 => 0 => | + // Shuff[2][V[2]] | [0] = E | [0] = X | + // |----------------------------|----------------------------| + // Shuff[3] | s(, , 1) | s(, , 1) | + // Data returned | 0+1 => 1 => | 1+1 => 2 => 0 => | + // Shuff[3][V[3]] | [1] = F | [1] = Y | + // |----------------------------|----------------------------| + // Result | C,D,E,F | G,H,X,Y | + IRBuilder<> B(CI); + + // Grab the packetized/vectorized sub-group local IDs + auto *const SubgroupLocalIDFn = Ctx.builtins().getOrDeclareMuxBuiltin( + compiler::utils::eMuxBuiltinGetSubGroupLocalId, *F.getParent(), + {CI->getType()}); + assert(SubgroupLocalIDFn); + + auto *const SubgroupLocalID = + B.CreateCall(SubgroupLocalIDFn, {}, "sg.local.id"); + auto const Builtin = + Ctx.builtins().analyzeBuiltinCall(*SubgroupLocalID, Dimension); + + // Vectorize the sub-group local ID + auto *const VecSubgroupLocalID = + vectorizeWorkGroupCall(SubgroupLocalID, Builtin); + if (!VecSubgroupLocalID) { + return Packetizer::Result(*this); + } + VecSubgroupLocalID->setName("vec.sg.local.id"); + + auto *const DeltaVal = PackDelta.getAsValue(); + + // The delta is always i32, as is the sub-group local ID. Vectorizing both of + // them should result in the same vector type, with as many elements as the + // vectorization factor. + assert(DeltaVal->getType() == VecSubgroupLocalID->getType() && + DeltaVal->getType()->isVectorTy() && + cast(DeltaVal->getType()) + ->getElementCount() + .getKnownMinValue() == VF && + "Unexpected vectorization of sub-group shuffle up/down"); + + // Produce the sum of the sub-group IDs with the 'delta', as per the + // semantics of the builtin. + auto *const IDPlusDelta = IsDown ? B.CreateAdd(VecSubgroupLocalID, DeltaVal) + : B.CreateSub(VecSubgroupLocalID, DeltaVal); + + // We need to sanitize the input indices so that they stay within the range + // of one vectorized group. + auto *const VecIdxFactor = ConstantInt::get(SubgroupLocalID->getType(), VF); + + // Bring this ID into the range of 'mux' sub-groups by dividing it by the + // vector size. We have to do this differently for 'up' and 'down' shuffles + // because the 'up' shuffles use signed indexing, and we need to round down + // to negative infinity to get the right sub-group delta. + Value *MuxAbsoluteIDs = nullptr; + Value *VecEltIDs = nullptr; + if (IsDown) { + MuxAbsoluteIDs = + B.CreateUDiv(IDPlusDelta, B.CreateVectorSplat(VF, VecIdxFactor)); + // And into the range of the vector group + VecEltIDs = + B.CreateURem(IDPlusDelta, B.CreateVectorSplat(VF, VecIdxFactor)); + } else { + // Note that shuffling up is more complicated, owing to the signed + // sub-group local IDs. + // The steps are identical to the example outlined above, except both the + // division and modulo operations performed on the sub-group IDs have to + // floor towards negative infinity. That is, we want to see: + // |----------------------------|---------------------------| + // | shuffle_up(A, X, 2) | shuffle_up(E, I, 2) | + // VF=4 |----------------------------|---------------------------| + // | s(, , 2) | s(, , 2)| + // SGIds | 0,1,2,3 | 4,5,6,7 | + // SGIds-D | -2,-1,0,1 | 2,3,4,5 | + // MuxSGIds | 0,0,0,0 | 1,1,1,1 | + // |----------------------------|---------------------------| + // both flooring: | | | + // M=(SGIds-D)/VF | -1,-1,0,0 | 0,0,1,1 | + // V=(SGIds-D)%VF | 2,3,0,1 | 2,3,0,1 | + // |----------------------------|---------------------------| + // MuxSGIds - M | 1,1,0,0 | 1,1,0,0 | + // |----------------------------|---------------------------| + // + // We use the following formulae for division and modulo: + // int div_floor(int x, int y) { + // int q = x/y; + // int r = x%y; + // if ((r!=0) && ((r<0) != (y<0))) --q; + // return q; + // } + // int mod_floor(int x, int y) { + // int r = x%y; + // if ((r!=0) && ((r<0) != (y<0))) { r += y; } + // return r; + // } + // We note also that the conditions are equal between the two operations, + // and that the condition is equivalent to: + // if ((r!=0) && ((x ^ y) < 0)) { ... } + // (see https://alive2.llvm.org/ce/z/ebGrdL) + auto *X = IDPlusDelta; + auto *Y = B.CreateVectorSplat(VF, VecIdxFactor); + auto *const Quotient = B.CreateSDiv(X, Y, "quotient"); + auto *const Remainder = B.CreateSRem(X, Y, "remainder"); + + auto *const ArgXor = B.CreateXor(X, Y, "arg.xor"); + auto *const One = ConstantInt::get(ArgXor->getType(), 1); + auto *const Zero = ConstantInt::get(ArgXor->getType(), 0); + auto *const ArgSignDifferent = + B.CreateICmpSLT(ArgXor, Zero, "signs.different"); + auto *const RemainderIsNotZero = + B.CreateICmpNE(Remainder, Zero, "remainder.nonzero"); + auto *const ConditionHolds = + B.CreateAnd(RemainderIsNotZero, ArgSignDifferent, "condition.holds"); + auto *const QuotientMinus1 = B.CreateSub(Quotient, One, "quotient.minus.1"); + auto *const RemainderPlusY = B.CreateAdd(Remainder, Y, "remainder.plus.y"); + + MuxAbsoluteIDs = B.CreateSelect(ConditionHolds, QuotientMinus1, Quotient); + VecEltIDs = B.CreateSelect(ConditionHolds, RemainderPlusY, Remainder); + } + + // We've produced the 'absolute' mux sub-group local IDs for the data we want + // to access in each shuffle, but we want to get back to 'relative' IDs in + // the form of deltas. Splat the mux sub-group local ID. + auto *const SplatSubgroupLocalID = + B.CreateVectorSplat(VF, SubgroupLocalID, "splat.sg.local.id"); + auto *DeltaLHS = MuxAbsoluteIDs; + auto *DeltaRHS = SplatSubgroupLocalID; + if (!IsDown) { + // For 'up' shuffles, we invert the operation as the deltas are implicitly + // negative. See above. + std::swap(DeltaLHS, DeltaRHS); + } + auto *const MuxDeltas = + B.CreateSub(DeltaLHS, DeltaRHS, "mux.sg.local.id.deltas"); + + auto ShuffleID = Ctx.builtins().getMuxGroupCollective(ShuffleUpDown); + auto *const ShuffleFn = Ctx.builtins().getOrDeclareMuxBuiltin( + ShuffleID, *F.getParent(), {LHSPackVal->getType()}); + assert(ShuffleFn); + + SmallVector Results(VF); + for (unsigned i = 0; i != VF; i++) { + auto *const MuxDelta = B.CreateExtractElement(MuxDeltas, B.getInt32(i)); + auto *const Shuffle = + B.CreateCall(ShuffleFn, {LHSPackVal, RHSPackVal, MuxDelta}); + + Value *Elt = nullptr; + auto *const Idx = B.CreateExtractElement(VecEltIDs, B.getInt32(i)); + if (auto *DataVecTy = dyn_cast(LHSOp->getType()); !DataVecTy) { + Elt = B.CreateExtractElement(Shuffle, Idx); + } else { + // For vector data types we need to extract consecutive elements starting + // at the sub-vector whose index is Idx. + Elt = UndefValue::get(DataVecTy); + auto VecWidth = DataVecTy->getElementCount().getFixedValue(); + // Idx is the 'base' of the subvector, whose elements are stored + // sequentially from that point. + auto *const VecVecGroupIdx = B.CreateMul(Idx, B.getInt32(VecWidth)); + for (unsigned j = 0; j != VecWidth; j++) { + auto *const E = B.CreateExtractElement( + Shuffle, B.CreateAdd(VecVecGroupIdx, B.getInt32(j))); + Elt = B.CreateInsertElement(Elt, E, B.getInt32(j)); + } + } + Results[i] = Elt; + } + + IC.deleteInstructionLater(CI); + return getPacketizationResult(I, Results); +} + Value *Packetizer::Impl::packetizeMaskVarying(Instruction *I) { if (auto memop = MemOp::get(I)) { auto *const mask = memop->getMaskOperand(); diff --git a/modules/compiler/vecz/test/lit/llvm/subgroup_shuffle_down.ll b/modules/compiler/vecz/test/lit/llvm/subgroup_shuffle_down.ll index a3566c8e6..3e90d729f 100644 --- a/modules/compiler/vecz/test/lit/llvm/subgroup_shuffle_down.ll +++ b/modules/compiler/vecz/test/lit/llvm/subgroup_shuffle_down.ll @@ -20,17 +20,187 @@ target triple = "spir64-unknown-unknown" target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128" -; CHECK: Could not packetize sub-group shuffle %shuffle_down -define spir_kernel void @kernel(ptr %in, ptr %out) { +; CHECK-LABEL: define spir_kernel void @__vecz_v4_kernel(ptr %lhsptr, ptr %rhsptr, ptr %out) +; CHECK: [[LHS:%.*]] = load <4 x float>, ptr %arrayidx.lhs, align 4 +; CHECK: [[RHS:%.*]] = load <4 x float>, ptr %arrayidx.rhs, align 4 + +; CHECK: [[DELTAS:%.*]] = add <4 x i32> {{%.*}}, +; CHECK: [[MUXIDS:%.*]] = udiv <4 x i32> [[DELTAS]], +; CHECK: [[VECELTS:%.*]] = urem <4 x i32> [[DELTAS]], +; CHECK: [[MUXDELTAS:%.*]] = sub <4 x i32> [[MUXIDS]], {{%.*}} + +; CHECK: [[DELTA0:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 0 +; CHECK: [[SHUFF0:%.*]] = call <4 x float> @__mux_sub_group_shuffle_down_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA0]]) +; CHECK: [[VECIDX0:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 0 +; CHECK: [[ELT0:%.*]] = extractelement <4 x float> [[SHUFF0]], i32 [[VECIDX0]] + +; CHECK: [[DELTA1:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 1 +; CHECK: [[SHUFF1:%.*]] = call <4 x float> @__mux_sub_group_shuffle_down_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA1]]) +; CHECK: [[VECIDX1:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 1 +; CHECK: [[ELT1:%.*]] = extractelement <4 x float> [[SHUFF1]], i32 [[VECIDX1]] + +; CHECK: [[DELTA2:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 2 +; CHECK: [[SHUFF2:%.*]] = call <4 x float> @__mux_sub_group_shuffle_down_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA2]]) +; CHECK: [[VECIDX2:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 2 +; CHECK: [[ELT2:%.*]] = extractelement <4 x float> [[SHUFF2]], i32 [[VECIDX2]] + +; CHECK: [[DELTA3:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 3 +; CHECK: [[SHUFF3:%.*]] = call <4 x float> @__mux_sub_group_shuffle_down_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA3]]) +; CHECK: [[VECIDX3:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 3 +; CHECK: [[ELT3:%.*]] = extractelement <4 x float> [[SHUFF3]], i32 [[VECIDX3]] +define spir_kernel void @kernel(ptr %lhsptr, ptr %rhsptr, ptr %out) { + %gid = tail call i64 @__mux_get_global_id(i32 0) + %arrayidx.lhs = getelementptr inbounds float, ptr %lhsptr, i64 %gid + %lhs = load float, ptr %arrayidx.lhs, align 4 + %arrayidx.rhs = getelementptr inbounds float, ptr %rhsptr, i64 %gid + %rhs = load float, ptr %arrayidx.rhs, align 4 + %shuffle_up = call float @__mux_sub_group_shuffle_down_f32(float %lhs, float %rhs, i32 1) + %arrayidx.out = getelementptr inbounds float, ptr %out, i64 %gid + store float %shuffle_up, ptr %arrayidx.out, align 8 + ret void +} + +; CHECK-LABEL: define spir_kernel void @__vecz_v4_kernel_vec_data(ptr %lhsptr, ptr %rhsptr, ptr %out) +; CHECK: [[DELTAS:%.*]] = add <4 x i32> {{%.*}}, +; CHECK: [[MUXIDS:%.*]] = udiv <4 x i32> [[DELTAS]], +; CHECK: [[VECELTS:%.*]] = urem <4 x i32> [[DELTAS]], +; CHECK: [[MUXDELTAS:%.*]] = sub <4 x i32> [[MUXIDS]], {{%.*}} + +; CHECK: [[DELTA0:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 0 +; CHECK: [[SHUFF0:%.*]] = call <16 x i8> @__mux_sub_group_shuffle_down_v16i8( +; CHECK-SAME: <16 x i8> [[LHS:%.*]], <16 x i8> [[RHS:%.*]], i32 [[DELTA0]]) +; CHECK: [[SUBVECIDX0:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 0 +; CHECK: [[ELTBASE0:%.*]] = mul i32 [[SUBVECIDX0]], 4 +; CHECK: [[VECIDX00:%.*]] = add i32 [[ELTBASE0]], 0 +; CHECK: [[ELT00:%.*]] = extractelement <16 x i8> [[SHUFF0]], i32 [[VECIDX00]] +; CHECK: [[VEC00:%.*]] = insertelement <4 x i8> undef, i8 [[ELT00]], i32 0 +; CHECK: [[VECIDX01:%.*]] = add i32 [[ELTBASE0]], 1 +; CHECK: [[ELT01:%.*]] = extractelement <16 x i8> [[SHUFF0]], i32 [[VECIDX01]] +; CHECK: [[VEC01:%.*]] = insertelement <4 x i8> [[VEC00]], i8 [[ELT01]], i32 1 +; CHECK: [[VECIDX02:%.*]] = add i32 [[ELTBASE0]], 2 +; CHECK: [[ELT02:%.*]] = extractelement <16 x i8> [[SHUFF0]], i32 [[VECIDX02]] +; CHECK: [[VEC02:%.*]] = insertelement <4 x i8> [[VEC01]], i8 [[ELT02]], i32 2 +; CHECK: [[VECIDX03:%.*]] = add i32 [[ELTBASE0]], 3 +; CHECK: [[ELT03:%.*]] = extractelement <16 x i8> [[SHUFF0]], i32 [[VECIDX03]] +; CHECK: [[VEC03:%.*]] = insertelement <4 x i8> [[VEC02]], i8 [[ELT03]], i32 3 + +; CHECK: [[DELTA1:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 1 +; CHECK: [[SHUFF1:%.*]] = call <16 x i8> @__mux_sub_group_shuffle_down_v16i8( +; CHECK-SAME: <16 x i8> [[LHS]], <16 x i8> [[RHS]], i32 [[DELTA1]]) +; CHECK: [[SUBVECIDX1:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 1 +; CHECK: [[ELTBASE1:%.*]] = mul i32 [[SUBVECIDX1]], 4 +; CHECK: [[VECIDX10:%.*]] = add i32 [[ELTBASE1]], 0 +; CHECK: [[ELT10:%.*]] = extractelement <16 x i8> [[SHUFF1]], i32 [[VECIDX10]] +; CHECK: [[VEC10:%.*]] = insertelement <4 x i8> undef, i8 [[ELT10]], i32 0 +; CHECK: [[VECIDX11:%.*]] = add i32 [[ELTBASE1]], 1 +; CHECK: [[ELT11:%.*]] = extractelement <16 x i8> [[SHUFF1]], i32 [[VECIDX11]] +; CHECK: [[VEC11:%.*]] = insertelement <4 x i8> [[VEC10]], i8 [[ELT11]], i32 1 +; CHECK: [[VECIDX12:%.*]] = add i32 [[ELTBASE1]], 2 +; CHECK: [[ELT12:%.*]] = extractelement <16 x i8> [[SHUFF1]], i32 [[VECIDX12]] +; CHECK: [[VEC12:%.*]] = insertelement <4 x i8> [[VEC11]], i8 [[ELT12]], i32 2 +; CHECK: [[VECIDX13:%.*]] = add i32 [[ELTBASE1]], 3 +; CHECK: [[ELT13:%.*]] = extractelement <16 x i8> [[SHUFF1]], i32 [[VECIDX13]] +; CHECK: [[VEC13:%.*]] = insertelement <4 x i8> [[VEC12]], i8 [[ELT13]], i32 3 + +; CHECK: [[DELTA2:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 2 +; CHECK: [[SHUFF2:%.*]] = call <16 x i8> @__mux_sub_group_shuffle_down_v16i8( +; CHECK-SAME: <16 x i8> [[LHS]], <16 x i8> [[RHS]], i32 [[DELTA2]]) +; CHECK: [[SUBVECIDX2:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 2 +; CHECK: [[ELTBASE2:%.*]] = mul i32 [[SUBVECIDX2]], 4 +; CHECK: [[VECIDX20:%.*]] = add i32 [[ELTBASE2]], 0 +; CHECK: [[ELT20:%.*]] = extractelement <16 x i8> [[SHUFF2]], i32 [[VECIDX20]] +; CHECK: [[VEC20:%.*]] = insertelement <4 x i8> undef, i8 [[ELT20]], i32 0 +; CHECK: [[VECIDX21:%.*]] = add i32 [[ELTBASE2]], 1 +; CHECK: [[ELT21:%.*]] = extractelement <16 x i8> [[SHUFF2]], i32 [[VECIDX21]] +; CHECK: [[VEC21:%.*]] = insertelement <4 x i8> [[VEC20]], i8 [[ELT21]], i32 1 +; CHECK: [[VECIDX22:%.*]] = add i32 [[ELTBASE2]], 2 +; CHECK: [[ELT22:%.*]] = extractelement <16 x i8> [[SHUFF2]], i32 [[VECIDX22]] +; CHECK: [[VEC22:%.*]] = insertelement <4 x i8> [[VEC21]], i8 [[ELT22]], i32 2 +; CHECK: [[VECIDX23:%.*]] = add i32 [[ELTBASE2]], 3 +; CHECK: [[ELT23:%.*]] = extractelement <16 x i8> [[SHUFF2]], i32 [[VECIDX23]] +; CHECK: [[VEC23:%.*]] = insertelement <4 x i8> [[VEC22]], i8 [[ELT23]], i32 3 + +; CHECK: [[DELTA3:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 3 +; CHECK: [[SHUFF3:%.*]] = call <16 x i8> @__mux_sub_group_shuffle_down_v16i8( +; CHECK-SAME: <16 x i8> [[LHS]], <16 x i8> [[RHS]], i32 [[DELTA3]]) +; CHECK: [[SUBVECIDX3:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 3 +; CHECK: [[ELTBASE3:%.*]] = mul i32 [[SUBVECIDX3]], 4 +; CHECK: [[VECIDX30:%.*]] = add i32 [[ELTBASE3]], 0 +; CHECK: [[ELT30:%.*]] = extractelement <16 x i8> [[SHUFF3]], i32 [[VECIDX30]] +; CHECK: [[VEC30:%.*]] = insertelement <4 x i8> undef, i8 [[ELT30]], i32 0 +; CHECK: [[VECIDX31:%.*]] = add i32 [[ELTBASE3]], 1 +; CHECK: [[ELT31:%.*]] = extractelement <16 x i8> [[SHUFF3]], i32 [[VECIDX31]] +; CHECK: [[VEC31:%.*]] = insertelement <4 x i8> [[VEC30]], i8 [[ELT31]], i32 1 +; CHECK: [[VECIDX32:%.*]] = add i32 [[ELTBASE3]], 2 +; CHECK: [[ELT32:%.*]] = extractelement <16 x i8> [[SHUFF3]], i32 [[VECIDX32]] +; CHECK: [[VEC32:%.*]] = insertelement <4 x i8> [[VEC31]], i8 [[ELT32]], i32 2 +; CHECK: [[VECIDX33:%.*]] = add i32 [[ELTBASE3]], 3 +; CHECK: [[ELT33:%.*]] = extractelement <16 x i8> [[SHUFF3]], i32 [[VECIDX33]] +; CHECK: [[VEC33:%.*]] = insertelement <4 x i8> [[VEC32]], i8 [[ELT33]], i32 3 +define spir_kernel void @kernel_vec_data(ptr %lhsptr, ptr %rhsptr, ptr %out) { + %gid = tail call i64 @__mux_get_global_id(i32 0) + %arrayidx.lhs = getelementptr inbounds <4 x i8>, ptr %lhsptr, i64 %gid + %lhs = load <4 x i8>, ptr %arrayidx.lhs, align 4 + %arrayidx.rhs = getelementptr inbounds <4 x i8>, ptr %rhsptr, i64 %gid + %rhs = load <4 x i8>, ptr %arrayidx.rhs, align 4 + %shuffle_up = call <4 x i8> @__mux_sub_group_shuffle_down_v4i8(<4 x i8> %lhs, <4 x i8> %rhs, i32 2) + %arrayidx.out = getelementptr inbounds <4 x i8>, ptr %out, i64 %gid + store <4 x i8> %shuffle_up, ptr %arrayidx.out, align 4 + ret void +} + +; CHECK-LABEL: define spir_kernel void @__vecz_v4_kernel_varying_delta(ptr %lhsptr, ptr %rhsptr, ptr %deltaptr, ptr %out) +; CHECK: [[LHS:%.*]] = load <4 x float>, ptr %arrayidx.lhs, align 4 +; CHECK: [[RHS:%.*]] = load <4 x float>, ptr %arrayidx.rhs, align 4 +; CHECK: [[DELTALD:%.*]] = load <4 x i32>, ptr %arrayidx.deltas, align 4 + +; CHECK: [[DELTAS:%.*]] = add <4 x i32> {{%.*}}, [[DELTALD]] +; CHECK: [[MUXIDS:%.*]] = udiv <4 x i32> [[DELTAS]], +; CHECK: [[VECELTS:%.*]] = urem <4 x i32> [[DELTAS]], +; CHECK: [[MUXDELTAS:%.*]] = sub <4 x i32> [[MUXIDS]], {{%.*}} + +; CHECK: [[DELTA0:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 0 +; CHECK: [[SHUFF0:%.*]] = call <4 x float> @__mux_sub_group_shuffle_down_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA0]]) +; CHECK: [[VECIDX0:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 0 +; CHECK: [[ELT0:%.*]] = extractelement <4 x float> [[SHUFF0]], i32 [[VECIDX0]] + +; CHECK: [[DELTA1:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 1 +; CHECK: [[SHUFF1:%.*]] = call <4 x float> @__mux_sub_group_shuffle_down_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA1]]) +; CHECK: [[VECIDX1:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 1 +; CHECK: [[ELT1:%.*]] = extractelement <4 x float> [[SHUFF1]], i32 [[VECIDX1]] + +; CHECK: [[DELTA2:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 2 +; CHECK: [[SHUFF2:%.*]] = call <4 x float> @__mux_sub_group_shuffle_down_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA2]]) +; CHECK: [[VECIDX2:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 2 +; CHECK: [[ELT2:%.*]] = extractelement <4 x float> [[SHUFF2]], i32 [[VECIDX2]] + +; CHECK: [[DELTA3:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 3 +; CHECK: [[SHUFF3:%.*]] = call <4 x float> @__mux_sub_group_shuffle_down_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA3]]) +; CHECK: [[VECIDX3:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 3 +; CHECK: [[ELT3:%.*]] = extractelement <4 x float> [[SHUFF3]], i32 [[VECIDX3]] +define spir_kernel void @kernel_varying_delta(ptr %lhsptr, ptr %rhsptr, ptr %deltaptr, ptr %out) { %gid = tail call i64 @__mux_get_global_id(i32 0) - %arrayidx.in = getelementptr inbounds i8, ptr %in, i64 %gid - %val = load i8, ptr %arrayidx.in, align 8 - %shuffle_down = call i8 @__mux_sub_group_shuffle_down_i8(i8 %val, i8 %val, i32 1) - %arrayidx.out = getelementptr inbounds i8, ptr %out, i64 %gid - store i8 %shuffle_down, ptr %arrayidx.out, align 8 + %arrayidx.lhs = getelementptr inbounds float, ptr %lhsptr, i64 %gid + %lhs = load float, ptr %arrayidx.lhs, align 4 + %arrayidx.rhs = getelementptr inbounds float, ptr %rhsptr, i64 %gid + %rhs = load float, ptr %arrayidx.rhs, align 4 + %arrayidx.deltas = getelementptr inbounds i32, ptr %deltaptr, i64 %gid + %delta = load i32, ptr %arrayidx.deltas, align 4 + %shuffle_up = call float @__mux_sub_group_shuffle_down_f32(float %lhs, float %rhs, i32 %delta) + %arrayidx.out = getelementptr inbounds float, ptr %out, i64 %gid + store float %shuffle_up, ptr %arrayidx.out, align 8 ret void } declare i64 @__mux_get_global_id(i32) -declare i8 @__mux_sub_group_shuffle_down_i8(i8 %curr, i8 %next, i32 %delta) +declare float @__mux_sub_group_shuffle_down_f32(float %prev, float %curr, i32 %delta) +declare <4 x i8> @__mux_sub_group_shuffle_down_v4i8(<4 x i8> %prev, <4 x i8> %curr, i32 %delta) diff --git a/modules/compiler/vecz/test/lit/llvm/subgroup_shuffle_up.ll b/modules/compiler/vecz/test/lit/llvm/subgroup_shuffle_up.ll index e27365e85..a3e645e88 100644 --- a/modules/compiler/vecz/test/lit/llvm/subgroup_shuffle_up.ll +++ b/modules/compiler/vecz/test/lit/llvm/subgroup_shuffle_up.ll @@ -20,12 +20,217 @@ target triple = "spir64-unknown-unknown" target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128" -; CHECK: Could not packetize sub-group shuffle %shuffle_up -define spir_kernel void @kernel(ptr %in, ptr %out) { +; CHECK-LABEL: define spir_kernel void @__vecz_v4_kernel(ptr %lhsptr, ptr %rhsptr, ptr %out) +; CHECK: [[LHS:%.*]] = load <4 x float>, ptr %arrayidx.lhs, align 4 +; CHECK: [[RHS:%.*]] = load <4 x float>, ptr %arrayidx.rhs, align 4 + +; CHECK: [[DELTAS:%.*]] = sub <4 x i32> {{%.*}}, +; CHECK: [[QUOTIENT:%.*]] = sdiv <4 x i32> [[DELTAS]], +; CHECK: [[REMAINDER:%.*]] = srem <4 x i32> [[DELTAS]], + +; CHECK: [[ARGXOR:%.*]] = xor <4 x i32> [[DELTAS]], +; CHECK: [[SIGNDIFF:%.*]] = icmp slt <4 x i32> [[ARGXOR]], zeroinitializer +; CHECK: [[REMNONZERO:%.*]] = icmp ne <4 x i32> [[REMAINDER]], zeroinitializer +; CHECK: [[CONDITION:%.*]] = and <4 x i1> [[REMNONZERO]], [[SIGNDIFF]] + +; CHECK: [[MIN1:%.*]] = sub <4 x i32> [[QUOTIENT]], +; CHECK: [[PLUSR:%.*]] = add <4 x i32> [[REMAINDER]], + +; CHECK: [[MUXIDS:%.*]] = select <4 x i1> [[CONDITION]], <4 x i32> [[MIN1]], <4 x i32> [[QUOTIENT]] +; CHECK: [[VECELTS:%.*]] = select <4 x i1> [[CONDITION]], <4 x i32> [[PLUSR]], <4 x i32> [[REMAINDER]] + +; CHECK: [[MUXDELTAS:%.*]] = sub <4 x i32> {{%.*}}, [[MUXIDS]] + +; CHECK: [[DELTA0:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 0 +; CHECK: [[SHUFF0:%.*]] = call <4 x float> @__mux_sub_group_shuffle_up_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA0]]) +; CHECK: [[VECIDX0:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 0 +; CHECK: [[ELT0:%.*]] = extractelement <4 x float> [[SHUFF0]], i32 [[VECIDX0]] + +; CHECK: [[DELTA1:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 1 +; CHECK: [[SHUFF1:%.*]] = call <4 x float> @__mux_sub_group_shuffle_up_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA1]]) +; CHECK: [[VECIDX1:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 1 +; CHECK: [[ELT1:%.*]] = extractelement <4 x float> [[SHUFF1]], i32 [[VECIDX1]] + +; CHECK: [[DELTA2:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 2 +; CHECK: [[SHUFF2:%.*]] = call <4 x float> @__mux_sub_group_shuffle_up_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA2]]) +; CHECK: [[VECIDX2:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 2 +; CHECK: [[ELT2:%.*]] = extractelement <4 x float> [[SHUFF2]], i32 [[VECIDX2]] + +; CHECK: [[DELTA3:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 3 +; CHECK: [[SHUFF3:%.*]] = call <4 x float> @__mux_sub_group_shuffle_up_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA3]]) +; CHECK: [[VECIDX3:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 3 +; CHECK: [[ELT3:%.*]] = extractelement <4 x float> [[SHUFF3]], i32 [[VECIDX3]] +define spir_kernel void @kernel(ptr %lhsptr, ptr %rhsptr, ptr %out) { + %gid = tail call i64 @__mux_get_global_id(i32 0) + %arrayidx.lhs = getelementptr inbounds float, ptr %lhsptr, i64 %gid + %lhs = load float, ptr %arrayidx.lhs, align 4 + %arrayidx.rhs = getelementptr inbounds float, ptr %rhsptr, i64 %gid + %rhs = load float, ptr %arrayidx.rhs, align 4 + %shuffle_up = call float @__mux_sub_group_shuffle_up_f32(float %lhs, float %rhs, i32 1) + %arrayidx.out = getelementptr inbounds float, ptr %out, i64 %gid + store float %shuffle_up, ptr %arrayidx.out, align 8 + ret void +} + +; CHECK-LABEL: define spir_kernel void @__vecz_v4_kernel_vec_data(ptr %lhsptr, ptr %rhsptr, ptr %out) +; CHECK: [[DELTAS:%.*]] = sub <4 x i32> {{%.*}}, +; CHECK: [[QUOTIENT:%.*]] = sdiv <4 x i32> [[DELTAS]], +; CHECK: [[REMAINDER:%.*]] = srem <4 x i32> [[DELTAS]], + +; CHECK: [[ARGXOR:%.*]] = xor <4 x i32> [[DELTAS]], +; CHECK: [[SIGNDIFF:%.*]] = icmp slt <4 x i32> [[ARGXOR]], zeroinitializer +; CHECK: [[REMNONZERO:%.*]] = icmp ne <4 x i32> [[REMAINDER]], zeroinitializer +; CHECK: [[CONDITION:%.*]] = and <4 x i1> [[REMNONZERO]], [[SIGNDIFF]] + +; CHECK: [[MIN1:%.*]] = sub <4 x i32> [[QUOTIENT]], +; CHECK: [[PLUSR:%.*]] = add <4 x i32> [[REMAINDER]], + +; CHECK: [[MUXIDS:%.*]] = select <4 x i1> [[CONDITION]], <4 x i32> [[MIN1]], <4 x i32> [[QUOTIENT]] +; CHECK: [[VECELTS:%.*]] = select <4 x i1> [[CONDITION]], <4 x i32> [[PLUSR]], <4 x i32> [[REMAINDER]] + +; CHECK: [[MUXDELTAS:%.*]] = sub <4 x i32> {{%.*}}, [[MUXIDS]] + +; CHECK: [[DELTA0:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 0 +; CHECK: [[SHUFF0:%.*]] = call <16 x i8> @__mux_sub_group_shuffle_up_v16i8( +; CHECK-SAME: <16 x i8> [[LHS:%.*]], <16 x i8> [[RHS:%.*]], i32 [[DELTA0]]) +; CHECK: [[SUBVECIDX0:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 0 +; CHECK: [[ELTBASE0:%.*]] = mul i32 [[SUBVECIDX0]], 4 +; CHECK: [[VECIDX00:%.*]] = add i32 [[ELTBASE0]], 0 +; CHECK: [[ELT00:%.*]] = extractelement <16 x i8> [[SHUFF0]], i32 [[VECIDX00]] +; CHECK: [[VEC00:%.*]] = insertelement <4 x i8> undef, i8 [[ELT00]], i32 0 +; CHECK: [[VECIDX01:%.*]] = add i32 [[ELTBASE0]], 1 +; CHECK: [[ELT01:%.*]] = extractelement <16 x i8> [[SHUFF0]], i32 [[VECIDX01]] +; CHECK: [[VEC01:%.*]] = insertelement <4 x i8> [[VEC00]], i8 [[ELT01]], i32 1 +; CHECK: [[VECIDX02:%.*]] = add i32 [[ELTBASE0]], 2 +; CHECK: [[ELT02:%.*]] = extractelement <16 x i8> [[SHUFF0]], i32 [[VECIDX02]] +; CHECK: [[VEC02:%.*]] = insertelement <4 x i8> [[VEC01]], i8 [[ELT02]], i32 2 +; CHECK: [[VECIDX03:%.*]] = add i32 [[ELTBASE0]], 3 +; CHECK: [[ELT03:%.*]] = extractelement <16 x i8> [[SHUFF0]], i32 [[VECIDX03]] +; CHECK: [[VEC03:%.*]] = insertelement <4 x i8> [[VEC02]], i8 [[ELT03]], i32 3 + +; CHECK: [[DELTA1:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 1 +; CHECK: [[SHUFF1:%.*]] = call <16 x i8> @__mux_sub_group_shuffle_up_v16i8( +; CHECK-SAME: <16 x i8> [[LHS]], <16 x i8> [[RHS]], i32 [[DELTA1]]) +; CHECK: [[SUBVECIDX1:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 1 +; CHECK: [[ELTBASE1:%.*]] = mul i32 [[SUBVECIDX1]], 4 +; CHECK: [[VECIDX10:%.*]] = add i32 [[ELTBASE1]], 0 +; CHECK: [[ELT10:%.*]] = extractelement <16 x i8> [[SHUFF1]], i32 [[VECIDX10]] +; CHECK: [[VEC10:%.*]] = insertelement <4 x i8> undef, i8 [[ELT10]], i32 0 +; CHECK: [[VECIDX11:%.*]] = add i32 [[ELTBASE1]], 1 +; CHECK: [[ELT11:%.*]] = extractelement <16 x i8> [[SHUFF1]], i32 [[VECIDX11]] +; CHECK: [[VEC11:%.*]] = insertelement <4 x i8> [[VEC10]], i8 [[ELT11]], i32 1 +; CHECK: [[VECIDX12:%.*]] = add i32 [[ELTBASE1]], 2 +; CHECK: [[ELT12:%.*]] = extractelement <16 x i8> [[SHUFF1]], i32 [[VECIDX12]] +; CHECK: [[VEC12:%.*]] = insertelement <4 x i8> [[VEC11]], i8 [[ELT12]], i32 2 +; CHECK: [[VECIDX13:%.*]] = add i32 [[ELTBASE1]], 3 +; CHECK: [[ELT13:%.*]] = extractelement <16 x i8> [[SHUFF1]], i32 [[VECIDX13]] +; CHECK: [[VEC13:%.*]] = insertelement <4 x i8> [[VEC12]], i8 [[ELT13]], i32 3 + +; CHECK: [[DELTA2:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 2 +; CHECK: [[SHUFF2:%.*]] = call <16 x i8> @__mux_sub_group_shuffle_up_v16i8( +; CHECK-SAME: <16 x i8> [[LHS]], <16 x i8> [[RHS]], i32 [[DELTA2]]) +; CHECK: [[SUBVECIDX2:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 2 +; CHECK: [[ELTBASE2:%.*]] = mul i32 [[SUBVECIDX2]], 4 +; CHECK: [[VECIDX20:%.*]] = add i32 [[ELTBASE2]], 0 +; CHECK: [[ELT20:%.*]] = extractelement <16 x i8> [[SHUFF2]], i32 [[VECIDX20]] +; CHECK: [[VEC20:%.*]] = insertelement <4 x i8> undef, i8 [[ELT20]], i32 0 +; CHECK: [[VECIDX21:%.*]] = add i32 [[ELTBASE2]], 1 +; CHECK: [[ELT21:%.*]] = extractelement <16 x i8> [[SHUFF2]], i32 [[VECIDX21]] +; CHECK: [[VEC21:%.*]] = insertelement <4 x i8> [[VEC20]], i8 [[ELT21]], i32 1 +; CHECK: [[VECIDX22:%.*]] = add i32 [[ELTBASE2]], 2 +; CHECK: [[ELT22:%.*]] = extractelement <16 x i8> [[SHUFF2]], i32 [[VECIDX22]] +; CHECK: [[VEC22:%.*]] = insertelement <4 x i8> [[VEC21]], i8 [[ELT22]], i32 2 +; CHECK: [[VECIDX23:%.*]] = add i32 [[ELTBASE2]], 3 +; CHECK: [[ELT23:%.*]] = extractelement <16 x i8> [[SHUFF2]], i32 [[VECIDX23]] +; CHECK: [[VEC23:%.*]] = insertelement <4 x i8> [[VEC22]], i8 [[ELT23]], i32 3 + +; CHECK: [[DELTA3:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 3 +; CHECK: [[SHUFF3:%.*]] = call <16 x i8> @__mux_sub_group_shuffle_up_v16i8( +; CHECK-SAME: <16 x i8> [[LHS]], <16 x i8> [[RHS]], i32 [[DELTA3]]) +; CHECK: [[SUBVECIDX3:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 3 +; CHECK: [[ELTBASE3:%.*]] = mul i32 [[SUBVECIDX3]], 4 +; CHECK: [[VECIDX30:%.*]] = add i32 [[ELTBASE3]], 0 +; CHECK: [[ELT30:%.*]] = extractelement <16 x i8> [[SHUFF3]], i32 [[VECIDX30]] +; CHECK: [[VEC30:%.*]] = insertelement <4 x i8> undef, i8 [[ELT30]], i32 0 +; CHECK: [[VECIDX31:%.*]] = add i32 [[ELTBASE3]], 1 +; CHECK: [[ELT31:%.*]] = extractelement <16 x i8> [[SHUFF3]], i32 [[VECIDX31]] +; CHECK: [[VEC31:%.*]] = insertelement <4 x i8> [[VEC30]], i8 [[ELT31]], i32 1 +; CHECK: [[VECIDX32:%.*]] = add i32 [[ELTBASE3]], 2 +; CHECK: [[ELT32:%.*]] = extractelement <16 x i8> [[SHUFF3]], i32 [[VECIDX32]] +; CHECK: [[VEC32:%.*]] = insertelement <4 x i8> [[VEC31]], i8 [[ELT32]], i32 2 +; CHECK: [[VECIDX33:%.*]] = add i32 [[ELTBASE3]], 3 +; CHECK: [[ELT33:%.*]] = extractelement <16 x i8> [[SHUFF3]], i32 [[VECIDX33]] +; CHECK: [[VEC33:%.*]] = insertelement <4 x i8> [[VEC32]], i8 [[ELT33]], i32 3 +define spir_kernel void @kernel_vec_data(ptr %lhsptr, ptr %rhsptr, ptr %out) { + %gid = tail call i64 @__mux_get_global_id(i32 0) + %arrayidx.lhs = getelementptr inbounds <4 x i8>, ptr %lhsptr, i64 %gid + %lhs = load <4 x i8>, ptr %arrayidx.lhs, align 4 + %arrayidx.rhs = getelementptr inbounds <4 x i8>, ptr %rhsptr, i64 %gid + %rhs = load <4 x i8>, ptr %arrayidx.rhs, align 4 + %shuffle_up = call <4 x i8> @__mux_sub_group_shuffle_up_v4i8(<4 x i8> %lhs, <4 x i8> %rhs, i32 2) + %arrayidx.out = getelementptr inbounds <4 x i8>, ptr %out, i64 %gid + store <4 x i8> %shuffle_up, ptr %arrayidx.out, align 4 + ret void +} + +; CHECK-LABEL: define spir_kernel void @__vecz_v4_kernel_varying_delta(ptr %lhsptr, ptr %rhsptr, ptr %deltaptr, ptr %out) +; CHECK: [[LHS:%.*]] = load <4 x float>, ptr %arrayidx.lhs, align 4 +; CHECK: [[RHS:%.*]] = load <4 x float>, ptr %arrayidx.rhs, align 4 +; CHECK: [[DELTALD:%.*]] = load <4 x i32>, ptr %arrayidx.deltas, align 4 + +; CHECK: [[DELTAS:%.*]] = sub <4 x i32> {{%.*}}, [[DELTALD]] +; CHECK: [[QUOTIENT:%.*]] = sdiv <4 x i32> [[DELTAS]], +; CHECK: [[REMAINDER:%.*]] = srem <4 x i32> [[DELTAS]], + +; CHECK: [[ARGXOR:%.*]] = xor <4 x i32> [[DELTAS]], +; CHECK: [[SIGNDIFF:%.*]] = icmp slt <4 x i32> [[ARGXOR]], zeroinitializer +; CHECK: [[REMNONZERO:%.*]] = icmp ne <4 x i32> [[REMAINDER]], zeroinitializer +; CHECK: [[CONDITION:%.*]] = and <4 x i1> [[REMNONZERO]], [[SIGNDIFF]] + +; CHECK: [[MIN1:%.*]] = sub <4 x i32> [[QUOTIENT]], +; CHECK: [[PLUSR:%.*]] = add <4 x i32> [[REMAINDER]], + +; CHECK: [[MUXIDS:%.*]] = select <4 x i1> [[CONDITION]], <4 x i32> [[MIN1]], <4 x i32> [[QUOTIENT]] +; CHECK: [[VECELTS:%.*]] = select <4 x i1> [[CONDITION]], <4 x i32> [[PLUSR]], <4 x i32> [[REMAINDER]] + +; CHECK: [[MUXDELTAS:%.*]] = sub <4 x i32> {{%.*}}, [[MUXIDS]] + +; CHECK: [[DELTA0:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 0 +; CHECK: [[SHUFF0:%.*]] = call <4 x float> @__mux_sub_group_shuffle_up_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA0]]) +; CHECK: [[VECIDX0:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 0 +; CHECK: [[ELT0:%.*]] = extractelement <4 x float> [[SHUFF0]], i32 [[VECIDX0]] + +; CHECK: [[DELTA1:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 1 +; CHECK: [[SHUFF1:%.*]] = call <4 x float> @__mux_sub_group_shuffle_up_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA1]]) +; CHECK: [[VECIDX1:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 1 +; CHECK: [[ELT1:%.*]] = extractelement <4 x float> [[SHUFF1]], i32 [[VECIDX1]] + +; CHECK: [[DELTA2:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 2 +; CHECK: [[SHUFF2:%.*]] = call <4 x float> @__mux_sub_group_shuffle_up_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA2]]) +; CHECK: [[VECIDX2:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 2 +; CHECK: [[ELT2:%.*]] = extractelement <4 x float> [[SHUFF2]], i32 [[VECIDX2]] + +; CHECK: [[DELTA3:%.*]] = extractelement <4 x i32> [[MUXDELTAS]], i32 3 +; CHECK: [[SHUFF3:%.*]] = call <4 x float> @__mux_sub_group_shuffle_up_v4f32( +; CHECK-SAME: <4 x float> [[LHS]], <4 x float> [[RHS]], i32 [[DELTA3]]) +; CHECK: [[VECIDX3:%.*]] = extractelement <4 x i32> [[VECELTS]], i32 3 +; CHECK: [[ELT3:%.*]] = extractelement <4 x float> [[SHUFF3]], i32 [[VECIDX3]] +define spir_kernel void @kernel_varying_delta(ptr %lhsptr, ptr %rhsptr, ptr %deltaptr, ptr %out) { %gid = tail call i64 @__mux_get_global_id(i32 0) - %arrayidx.in = getelementptr inbounds float, ptr %in, i64 %gid - %val = load float, ptr %arrayidx.in, align 8 - %shuffle_up = call float @__mux_sub_group_shuffle_up_f32(float %val, float %val, i32 1) + %arrayidx.lhs = getelementptr inbounds float, ptr %lhsptr, i64 %gid + %lhs = load float, ptr %arrayidx.lhs, align 4 + %arrayidx.rhs = getelementptr inbounds float, ptr %rhsptr, i64 %gid + %rhs = load float, ptr %arrayidx.rhs, align 4 + %arrayidx.deltas = getelementptr inbounds i32, ptr %deltaptr, i64 %gid + %delta = load i32, ptr %arrayidx.deltas, align 4 + %shuffle_up = call float @__mux_sub_group_shuffle_up_f32(float %lhs, float %rhs, i32 %delta) %arrayidx.out = getelementptr inbounds float, ptr %out, i64 %gid store float %shuffle_up, ptr %arrayidx.out, align 8 ret void @@ -34,3 +239,4 @@ define spir_kernel void @kernel(ptr %in, ptr %out) { declare i64 @__mux_get_global_id(i32) declare float @__mux_sub_group_shuffle_up_f32(float %prev, float %curr, i32 %delta) +declare <4 x i8> @__mux_sub_group_shuffle_up_v4i8(<4 x i8> %prev, <4 x i8> %curr, i32 %delta)