Skip to content

Commit

Permalink
Merge pull request #192 from frasercrmck/native-vp-reductions
Browse files Browse the repository at this point in the history
[vecz] Support vector-predicated reductions natively
  • Loading branch information
frasercrmck authored Nov 6, 2023
2 parents 5d6a57c + 64d0ec5 commit 7bea3c5
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Analysis/IVDescriptors.h>
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/IR/IRBuilder.h>
#include <multi_llvm/llvm_version.h>
#include <multi_llvm/multi_llvm.h>
Expand Down Expand Up @@ -87,18 +88,21 @@ bool createSubSplats(const vecz::TargetInfo &TI, llvm::IRBuilder<> &B,
llvm::SmallVectorImpl<llvm::Value *> &srcs,
unsigned subWidth);

/// @brief Utility function for sanitizing the input to a reduction when
/// vector-predicating. Since VP reduction intrinsics didn't land in LLVM 13,
/// reductions must ensure that elements past VL don't affect the result.
/// @brief Utility function for creating a reduction operation.
///
/// Only works on RecurKind::And, Or, Add, SMin, SMax, UMin, UMax, FAdd.
/// The value must be a vector.
///
/// @param[in] B IRBuilder to build any new instructions created
/// @param[in] Val The value to sanitize
/// @param[in] VL The vector length
/// @param[in] Kind The kind of reduction to sanitize for
llvm::Value *sanitizeVPReductionInput(llvm::IRBuilder<> &B, llvm::Value *Val,
llvm::Value *VL, llvm::RecurKind Kind);
/// If VL is passed and is non-null, it is assumed to be the i32 value
/// representing the active vector length. The reduction will be
/// vector-predicated according to this length.
///
/// Only works on RecurKind::And, Or, Xor, Add, Mul, FAdd, FMul, {S,U,F}Min,
/// {S,U,F}Max.
///
/// @param[in] B Builder used to create any new instructions
/// @param[in] TTI Target transform info, forwarded to the regular
/// (non-predicated) reduction helper when VL is null
/// @param[in] Val The vector value to reduce
/// @param[in] Kind The kind of reduction to perform
/// @param[in] VL The active vector length, or null to create a regular
/// reduction over the whole vector
///
/// @return The result of the reduction. When vector-predicating, nullptr is
/// returned for RecurKind::None.
llvm::Value *createMaybeVPTargetReduction(llvm::IRBuilderBase &B,
const llvm::TargetTransformInfo &TTI,
llvm::Value *Val,
llvm::RecurKind Kind,
llvm::Value *VL = nullptr);

/// @brief Utility function to obtain an indices vector to be used in a gather
/// operation.
Expand All @@ -119,7 +123,7 @@ llvm::Value *getGatherIndicesVector(llvm::IRBuilder<> &B, llvm::Value *Indices,
const llvm::Twine &N = "");

/// @brief Returns a boolean vector with all elements set to 'true'.
llvm::Value *createAllTrueMask(llvm::IRBuilder<> &B, llvm::ElementCount EC);
llvm::Value *createAllTrueMask(llvm::IRBuilderBase &B, llvm::ElementCount EC);

/// @brief Returns an integer step vector, representing the sequence 0 ... N-1.
llvm::Value *createIndexSequence(llvm::IRBuilder<> &Builder,
Expand Down
78 changes: 67 additions & 11 deletions modules/compiler/vecz/source/transform/packetization_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <llvm/Analysis/VectorUtils.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/Transforms/Utils/LoopUtils.h>
#include <multi_llvm/multi_llvm.h>
#include <multi_llvm/vector_type_helper.h>

Expand Down Expand Up @@ -247,16 +249,70 @@ bool createSubSplats(const vecz::TargetInfo &TI, IRBuilder<> &B,
return true;
}

Value *sanitizeVPReductionInput(IRBuilder<> &B, Value *Val, Value *VL,
RecurKind Kind) {
Type *const ValTy = Val->getType();
ElementCount const EC = multi_llvm::getVectorElementCount(ValTy);
Value *const VLSplat = B.CreateVectorSplat(EC, VL);
Value *const IdxVec =
createIndexSequence(B, VectorType::get(VL->getType(), EC));
Value *const ActiveMask = B.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
auto *const NeutralVal = compiler::utils::getNeutralVal(Kind, ValTy);
return B.CreateSelect(ActiveMask, Val, NeutralVal);
/// @brief Creates a reduction of a vector value, vector-predicated if a
/// vector length is supplied.
///
/// With a null VL this defers to createSimpleTargetReduction. Otherwise it
/// emits a call to the llvm.vp.reduce.* intrinsic matching Kind, using the
/// neutral value of the reduction as the start value and an all-true mask, so
/// that VL alone determines which lanes take part.
///
/// @param[in] B Builder used to create any new instructions
/// @param[in] TTI Target transform info, only used for non-predicated
/// reductions
/// @param[in] Val The vector value to reduce
/// @param[in] Kind The kind of reduction to perform
/// @param[in] VL The i32 active vector length, or null for a regular
/// reduction
///
/// @return The result of the reduction. When vector-predicating, nullptr is
/// returned if Kind has no corresponding VP reduction intrinsic (including
/// RecurKind::None).
Value *createMaybeVPTargetReduction(IRBuilderBase &B,
                                    const TargetTransformInfo &TTI, Value *Val,
                                    RecurKind Kind, Value *VL) {
  assert(isa<VectorType>(Val->getType()) && "Must be vector type");
  // If VL is null, it's not a vector-predicated reduction.
  if (!VL) {
    return createSimpleTargetReduction(B, &TTI, Val, Kind);
  }
  assert(VL->getType()->isIntegerTy(32) &&
         "The active vector length must be an i32");

  auto IntrinsicOp = Intrinsic::not_intrinsic;
  switch (Kind) {
    default:
      // RecurKind::None and any other kind without a VP reduction intrinsic.
      // Bail out here: requesting the declaration of Intrinsic::not_intrinsic
      // below would be invalid.
      return nullptr;
    case RecurKind::Add:
      IntrinsicOp = Intrinsic::vp_reduce_add;
      break;
    case RecurKind::Mul:
      IntrinsicOp = Intrinsic::vp_reduce_mul;
      break;
    case RecurKind::Or:
      IntrinsicOp = Intrinsic::vp_reduce_or;
      break;
    case RecurKind::And:
      IntrinsicOp = Intrinsic::vp_reduce_and;
      break;
    case RecurKind::Xor:
      IntrinsicOp = Intrinsic::vp_reduce_xor;
      break;
    case RecurKind::FAdd:
      IntrinsicOp = Intrinsic::vp_reduce_fadd;
      break;
    case RecurKind::FMul:
      IntrinsicOp = Intrinsic::vp_reduce_fmul;
      break;
    case RecurKind::SMin:
      IntrinsicOp = Intrinsic::vp_reduce_smin;
      break;
    case RecurKind::SMax:
      IntrinsicOp = Intrinsic::vp_reduce_smax;
      break;
    case RecurKind::UMin:
      IntrinsicOp = Intrinsic::vp_reduce_umin;
      break;
    case RecurKind::UMax:
      IntrinsicOp = Intrinsic::vp_reduce_umax;
      break;
    case RecurKind::FMin:
      IntrinsicOp = Intrinsic::vp_reduce_fmin;
      break;
    case RecurKind::FMax:
      IntrinsicOp = Intrinsic::vp_reduce_fmax;
      break;
  }

  // The VP reduction intrinsics are overloaded on the type of the vector
  // operand.
  auto *const F = Intrinsic::getDeclaration(B.GetInsertBlock()->getModule(),
                                            IntrinsicOp, Val->getType());
  assert(F && "Could not declare vector-predicated reduction intrinsic");

  auto *const VecTy = cast<VectorType>(Val->getType());
  // Starting from the neutral value means the start operand does not
  // contribute to the result.
  auto *const NeutralVal =
      compiler::utils::getNeutralVal(Kind, VecTy->getElementType());
  // Predication is expressed through VL only, so every lane is enabled in the
  // mask.
  auto *const Mask = createAllTrueMask(B, VecTy->getElementCount());
  return B.CreateCall(F, {NeutralVal, Val, Mask, VL});
}

Value *getGatherIndicesVector(IRBuilder<> &B, Value *Indices, Type *Ty,
Expand All @@ -272,7 +328,7 @@ Value *getGatherIndicesVector(IRBuilder<> &B, Value *Indices, Type *Ty,
return B.CreateAdd(StepsMul, Indices, N);
}

Value *createAllTrueMask(IRBuilder<> &B, ElementCount EC) {
/// @brief Builds a constant boolean (i1) vector with EC elements, all of
/// which are 'true'.
Value *createAllTrueMask(IRBuilderBase &B, ElementCount EC) {
  auto *const MaskTy = VectorType::get(B.getInt1Ty(), EC);
  return ConstantInt::getTrue(MaskTy);
}

Expand Down
52 changes: 9 additions & 43 deletions modules/compiler/vecz/source/transform/packetizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,12 +839,7 @@ Value *Packetizer::Impl::reduceBranchCond(Value *cond, Instruction *terminator,
// value.
Value *&f = conds.front();

if (VL) {
f = sanitizeVPReductionInput(B, f, VL, kind);
VECZ_FAIL_IF(!f);
}

return createSimpleTargetReduction(B, &TTI, f, kind);
return createMaybeVPTargetReduction(B, TTI, f, kind, VL);
}

Packetizer::Result Packetizer::Impl::assign(Value *Scalar, Value *Vectorized) {
Expand Down Expand Up @@ -899,14 +894,7 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
if (newCond->getType()->isVectorTy()) {
IRBuilder<> B(Branch);
RecurKind kind = RecurKind::Or;
// Sanitize VP reduction inputs, if required.
if (VL) {
newCond = sanitizeVPReductionInput(B, newCond, VL, kind);
if (!newCond) {
return Packetizer::Result(*this);
}
}
newCond = createSimpleTargetReduction(B, &TTI, newCond, kind);
newCond = createMaybeVPTargetReduction(B, TTI, newCond, kind, VL);
}

Branch->setCondition(newCond);
Expand Down Expand Up @@ -1183,19 +1171,8 @@ Value *Packetizer::Impl::packetizeGroupReduction(Instruction *I) {
// them of ordering? See CA-3969.
op.getPacketValues(packetWidth, opPackets);

// When in VP mode, pre-sanitize the reduction input (before VP reduction
// intrinsics, introduced in LLVM 14)
if (VL) {
assert(opPackets.size() == 1 &&
"Should have bailed if dealing with more than one packet");
Value *&val = opPackets.front();
val = sanitizeVPReductionInput(B, val, VL, Info->Recurrence);
if (!val) {
emitVeczRemarkMissed(
&F, CI, "Can not vector-predicate workgroup/subgroup reduction");
return nullptr;
}
}
assert((!VL || packetWidth) &&
"Should have bailed if dealing with more than one VP packet");

// According to the OpenCL Spec, we are allowed to rearrange the operation
// order of a workgroup/subgroup reduction any way we like (even though
Expand All @@ -1216,8 +1193,8 @@ Value *Packetizer::Impl::packetizeGroupReduction(Instruction *I) {
}

// Reduce to a scalar.
Value *v =
createSimpleTargetReduction(B, &TTI, opPackets.front(), Info->Recurrence);
Value *v = createMaybeVPTargetReduction(B, TTI, opPackets.front(),
Info->Recurrence, VL);

// We leave the original reduction function and divert the vectorized
// reduction through it, giving us a reduction over the full apparent
Expand Down Expand Up @@ -1624,14 +1601,8 @@ Value *Packetizer::Impl::packetizeMaskVarying(Instruction *I) {
auto *maskInst = dyn_cast<Instruction>(vecMask);
IRBuilder<> B(maskInst ? buildAfter(maskInst, F) : I);

// Sanitize any vector-predicated inputs.
if (VL) {
vecMask = sanitizeVPReductionInput(B, vecMask, VL, RecurKind::Or);
VECZ_FAIL_IF(!vecMask);
}

Value *anyOfMask =
createSimpleTargetReduction(B, &TTI, vecMask, RecurKind::Or);
createMaybeVPTargetReduction(B, TTI, vecMask, RecurKind::Or, VL);
anyOfMask->setName("any_of_mask");

if (isVector) {
Expand Down Expand Up @@ -2072,13 +2043,8 @@ ValuePacket Packetizer::Impl::packetizeGroupScan(
// Thus we essentially keep the original group scan, but change it to be an
// exclusive one.
auto *Reduction = Ops.front();
if (VL) {
Reduction = sanitizeVPReductionInput(B, Reduction, VL, Scan.Recurrence);
if (!Reduction) {
return results;
}
}
Reduction = createSimpleTargetReduction(B, &TTI, Reduction, Scan.Recurrence);
Reduction =
createMaybeVPTargetReduction(B, TTI, Reduction, Scan.Recurrence, VL);

// Now we defer to an *exclusive* scan over the group.
auto ExclScan = Scan;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,4 @@ if.end: ; preds = %if.then, %entry

; CHECK: define spir_kernel void @__vecz_nxv2_vp_foo(ptr addrspace(1) nocapture readonly %a, ptr addrspace(1) nocapture %out)
; CHECK: [[CMP:%.*]] = fcmp oeq <vscale x 2 x float> %{{.*}}, zeroinitializer
; CHECK: [[INS:%.*]] = insertelement <vscale x 2 x i32> poison, i32 [[VL:%.*]], {{(i32|i64)}} 0
; CHECK: [[SPLAT:%.*]] = shufflevector <vscale x 2 x i32> [[INS]], <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
; CHECK: [[IDX:%.*]] = call <vscale x 2 x i32> @llvm.experimental.stepvector.nxv2i32()
; CHECK: [[MASK:%.*]] = icmp ult <vscale x 2 x i32> [[IDX]], [[SPLAT]]
; CHECK: [[INP:%.*]] = select <vscale x 2 x i1> [[MASK]], <vscale x 2 x i1> [[CMP]], <vscale x 2 x i1> zeroinitializer
; CHECK: %{{.*}} = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[INP]])
; CHECK: %{{.*}} = call i1 @llvm.vp.reduce.or.nxv2i1(i1 false, <vscale x 2 x i1> [[CMP]], {{.*}}, i32 {{.*}})
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,7 @@ if.end:
ret void
; CHECK: define spir_kernel void @__vecz_nxv4_vp_mask_varying
; CHECK: [[CMP:%.*]] = icmp slt <vscale x 4 x i64> %{{.*}},
; CHECK: [[INS:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[VL:%.*]], {{(i32|i64)}} 0
; CHECK: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[INS]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK: [[IDX:%.*]] = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
; CHECK: [[MASK:%.*]] = icmp ult <vscale x 4 x i32> [[IDX]], [[SPLAT]]
; CHECK: [[INP:%.*]] = select <vscale x 4 x i1> [[MASK]], <vscale x 4 x i1> [[CMP]], <vscale x 4 x i1> zeroinitializer
; CHECK: [[RED:%.*]] = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> [[INP]])
; CHECK: [[RED:%.*]] = call i1 @llvm.vp.reduce.or.nxv4i1(i1 false, <vscale x 4 x i1> [[CMP]], {{.*}}, i32 {{.*}})
; CHECK: [[REINS:%.*]] = insertelement <4 x i1> poison, i1 [[RED]], {{(i32|i64)}} 0
; CHECK: [[RESPLAT:%.*]] = shufflevector <4 x i1> [[REINS]], <4 x i1> poison, <4 x i32> zeroinitializer
}
Expand Down
Loading

0 comments on commit 7bea3c5

Please sign in to comment.