Skip to content

Commit

Permalink
[vecz] Support vector-predicated reductions natively
Browse files Browse the repository at this point in the history
The initial vecz support for vector-predication was implemented around
LLVM 12, before there were intrinsics available for reduction
operations. This meant that we had to work around the lack of intrinsics by
using regular reduction intrinsics and 'sanitizing' the input by masking
out the unwanted vector elements with the neutral value.

Vector-predicated reduction intrinsics have been around since LLVM 14 so
it's high time we accommodate them natively. This should lead to better
code generation when vector-predicating kernels.
  • Loading branch information
frasercrmck committed Nov 6, 2023
1 parent 0b1216d commit 64d0ec5
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Analysis/IVDescriptors.h>
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/IR/IRBuilder.h>
#include <multi_llvm/llvm_version.h>
#include <multi_llvm/multi_llvm.h>
Expand Down Expand Up @@ -87,18 +88,21 @@ bool createSubSplats(const vecz::TargetInfo &TI, llvm::IRBuilder<> &B,
llvm::SmallVectorImpl<llvm::Value *> &srcs,
unsigned subWidth);

/// @brief Utility function for sanitizing the input to a reduction when
/// vector-predicating. Since VP reduction intrinsics didn't land in LLVM 13,
/// reductions must ensure that elements past VL don't affect the result.
/// @brief Utility function for creating a reduction operation.
///
/// Only works on RecurKind::And, Or, Add, SMin, SMax, UMin, UMax, FAdd.
/// The value must be a vector.
///
/// @param[in] B IRBuilder to build any new instructions created
/// @param[in] Val The value to sanitize
/// @param[in] VL The vector length
/// @param[in] Kind The kind of reduction to sanitize for
llvm::Value *sanitizeVPReductionInput(llvm::IRBuilder<> &B, llvm::Value *Val,
llvm::Value *VL, llvm::RecurKind Kind);
/// If VL is passed and is non-null, it is assumed to be the i32 value
/// representing the active vector length. The reduction will be
/// vector-predicated according to this length.
///
/// Only works on RecurKind::And, Or, Xor, Add, Mul, FAdd, FMul, {S,U,F}Min,
/// {S,U,F}Max.
///
/// @param[in] B IRBuilder to build any new instructions created
/// @param[in] TTI Target transform info, forwarded to the non-predicated
/// reduction when VL is null
/// @param[in] Val The vector value to reduce
/// @param[in] Kind The kind of reduction to create
/// @param[in] VL The active vector length, or null to create a regular
/// (non-predicated) reduction
/// @return The value produced by the reduction. May be null: the definition
/// returns null for RecurKind::None when VL is provided.
llvm::Value *createMaybeVPTargetReduction(llvm::IRBuilderBase &B,
const llvm::TargetTransformInfo &TTI,
llvm::Value *Val,
llvm::RecurKind Kind,
llvm::Value *VL = nullptr);

/// @brief Utility function to obtain an indices vector to be used in a gather
/// operation.
Expand All @@ -119,7 +123,7 @@ llvm::Value *getGatherIndicesVector(llvm::IRBuilder<> &B, llvm::Value *Indices,
const llvm::Twine &N = "");

/// @brief Returns a boolean vector with all elements set to 'true'.
llvm::Value *createAllTrueMask(llvm::IRBuilder<> &B, llvm::ElementCount EC);
llvm::Value *createAllTrueMask(llvm::IRBuilderBase &B, llvm::ElementCount EC);

/// @brief Returns an integer step vector, representing the sequence 0 ... N-1.
llvm::Value *createIndexSequence(llvm::IRBuilder<> &Builder,
Expand Down
78 changes: 67 additions & 11 deletions modules/compiler/vecz/source/transform/packetization_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <llvm/Analysis/VectorUtils.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/Transforms/Utils/LoopUtils.h>
#include <multi_llvm/multi_llvm.h>
#include <multi_llvm/vector_type_helper.h>

Expand Down Expand Up @@ -247,16 +249,70 @@ bool createSubSplats(const vecz::TargetInfo &TI, IRBuilder<> &B,
return true;
}

Value *sanitizeVPReductionInput(IRBuilder<> &B, Value *Val, Value *VL,
RecurKind Kind) {
Type *const ValTy = Val->getType();
ElementCount const EC = multi_llvm::getVectorElementCount(ValTy);
Value *const VLSplat = B.CreateVectorSplat(EC, VL);
Value *const IdxVec =
createIndexSequence(B, VectorType::get(VL->getType(), EC));
Value *const ActiveMask = B.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
auto *const NeutralVal = compiler::utils::getNeutralVal(Kind, ValTy);
return B.CreateSelect(ActiveMask, Val, NeutralVal);
/// Creates a reduction of the vector value `Val` of the given recurrence
/// `Kind`.
///
/// When `VL` is null a regular (non-predicated) reduction is created via
/// createSimpleTargetReduction. Otherwise a call to the matching
/// llvm.vp.reduce.* intrinsic is emitted, using the neutral value of the
/// recurrence as the start value, an all-true mask, and `VL` as the explicit
/// vector length.
///
/// Returns nullptr if a vector-predicated reduction is requested for a
/// recurrence kind that has no corresponding VP reduction intrinsic
/// (including RecurKind::None).
Value *createMaybeVPTargetReduction(IRBuilderBase &B,
                                    const TargetTransformInfo &TTI, Value *Val,
                                    RecurKind Kind, Value *VL) {
  assert(isa<VectorType>(Val->getType()) && "Must be vector type");
  // If VL is null, it's not a vector-predicated reduction.
  if (!VL) {
    return createSimpleTargetReduction(B, &TTI, Val, Kind);
  }

  auto IntrinsicOp = Intrinsic::not_intrinsic;
  switch (Kind) {
    case RecurKind::None:
    default:
      // There is no VP reduction intrinsic for this recurrence kind. Bail out
      // rather than asking for a declaration of Intrinsic::not_intrinsic,
      // which is not a valid intrinsic ID.
      return nullptr;
    case RecurKind::Add:
      IntrinsicOp = Intrinsic::vp_reduce_add;
      break;
    case RecurKind::Mul:
      IntrinsicOp = Intrinsic::vp_reduce_mul;
      break;
    case RecurKind::Or:
      IntrinsicOp = Intrinsic::vp_reduce_or;
      break;
    case RecurKind::And:
      IntrinsicOp = Intrinsic::vp_reduce_and;
      break;
    case RecurKind::Xor:
      IntrinsicOp = Intrinsic::vp_reduce_xor;
      break;
    case RecurKind::FAdd:
      IntrinsicOp = Intrinsic::vp_reduce_fadd;
      break;
    case RecurKind::FMul:
      IntrinsicOp = Intrinsic::vp_reduce_fmul;
      break;
    case RecurKind::SMin:
      IntrinsicOp = Intrinsic::vp_reduce_smin;
      break;
    case RecurKind::SMax:
      IntrinsicOp = Intrinsic::vp_reduce_smax;
      break;
    case RecurKind::UMin:
      IntrinsicOp = Intrinsic::vp_reduce_umin;
      break;
    case RecurKind::UMax:
      IntrinsicOp = Intrinsic::vp_reduce_umax;
      break;
    case RecurKind::FMin:
      IntrinsicOp = Intrinsic::vp_reduce_fmin;
      break;
    case RecurKind::FMax:
      IntrinsicOp = Intrinsic::vp_reduce_fmax;
      break;
  }

  auto *const VecTy = cast<VectorType>(Val->getType());

  // The start value of the reduction is the neutral value of the recurrence,
  // so that it does not affect the result.
  // NOTE(review): getNeutralVal is defined elsewhere; this guard assumes it
  // may return null for a kind/type it does not handle - confirm.
  auto *const NeutralVal =
      compiler::utils::getNeutralVal(Kind, VecTy->getElementType());
  if (!NeutralVal) {
    return nullptr;
  }

  // The VP reduction intrinsics are overloaded on the vector operand type.
  auto *const F = Intrinsic::getDeclaration(B.GetInsertBlock()->getModule(),
                                            IntrinsicOp, VecTy);
  assert(F && "Could not declare vector-predicated reduction intrinsic");

  // All lanes are enabled by the mask; only VL limits the active elements.
  auto *const Mask = createAllTrueMask(B, VecTy->getElementCount());
  return B.CreateCall(F, {NeutralVal, Val, Mask, VL});
}

Value *getGatherIndicesVector(IRBuilder<> &B, Value *Indices, Type *Ty,
Expand All @@ -272,7 +328,7 @@ Value *getGatherIndicesVector(IRBuilder<> &B, Value *Indices, Type *Ty,
return B.CreateAdd(StepsMul, Indices, N);
}

Value *createAllTrueMask(IRBuilder<> &B, ElementCount EC) {
/// Builds a constant vector of i1 with `EC` elements, every element 'true'.
/// The builder is only used to obtain the i1 type.
Value *createAllTrueMask(IRBuilderBase &B, ElementCount EC) {
  auto *const MaskTy = VectorType::get(B.getInt1Ty(), EC);
  return ConstantInt::getTrue(MaskTy);
}

Expand Down
52 changes: 9 additions & 43 deletions modules/compiler/vecz/source/transform/packetizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,12 +839,7 @@ Value *Packetizer::Impl::reduceBranchCond(Value *cond, Instruction *terminator,
// value.
Value *&f = conds.front();

if (VL) {
f = sanitizeVPReductionInput(B, f, VL, kind);
VECZ_FAIL_IF(!f);
}

return createSimpleTargetReduction(B, &TTI, f, kind);
return createMaybeVPTargetReduction(B, TTI, f, kind, VL);
}

Packetizer::Result Packetizer::Impl::assign(Value *Scalar, Value *Vectorized) {
Expand Down Expand Up @@ -899,14 +894,7 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
if (newCond->getType()->isVectorTy()) {
IRBuilder<> B(Branch);
RecurKind kind = RecurKind::Or;
// Sanitize VP reduction inputs, if required.
if (VL) {
newCond = sanitizeVPReductionInput(B, newCond, VL, kind);
if (!newCond) {
return Packetizer::Result(*this);
}
}
newCond = createSimpleTargetReduction(B, &TTI, newCond, kind);
newCond = createMaybeVPTargetReduction(B, TTI, newCond, kind, VL);
}

Branch->setCondition(newCond);
Expand Down Expand Up @@ -1183,19 +1171,8 @@ Value *Packetizer::Impl::packetizeGroupReduction(Instruction *I) {
// them of ordering? See CA-3969.
op.getPacketValues(packetWidth, opPackets);

// When in VP mode, pre-sanitize the reduction input (before VP reduction
// intrinsics, introduced in LLVM 14)
if (VL) {
assert(opPackets.size() == 1 &&
"Should have bailed if dealing with more than one packet");
Value *&val = opPackets.front();
val = sanitizeVPReductionInput(B, val, VL, Info->Recurrence);
if (!val) {
emitVeczRemarkMissed(
&F, CI, "Can not vector-predicate workgroup/subgroup reduction");
return nullptr;
}
}
assert((!VL || packetWidth) &&
"Should have bailed if dealing with more than one VP packet");

// According to the OpenCL Spec, we are allowed to rearrange the operation
// order of a workgroup/subgroup reduction any way we like (even though
Expand All @@ -1216,8 +1193,8 @@ Value *Packetizer::Impl::packetizeGroupReduction(Instruction *I) {
}

// Reduce to a scalar.
Value *v =
createSimpleTargetReduction(B, &TTI, opPackets.front(), Info->Recurrence);
Value *v = createMaybeVPTargetReduction(B, TTI, opPackets.front(),
Info->Recurrence, VL);

// We leave the original reduction function and divert the vectorized
// reduction through it, giving us a reduction over the full apparent
Expand Down Expand Up @@ -1624,14 +1601,8 @@ Value *Packetizer::Impl::packetizeMaskVarying(Instruction *I) {
auto *maskInst = dyn_cast<Instruction>(vecMask);
IRBuilder<> B(maskInst ? buildAfter(maskInst, F) : I);

// Sanitize any vector-predicated inputs.
if (VL) {
vecMask = sanitizeVPReductionInput(B, vecMask, VL, RecurKind::Or);
VECZ_FAIL_IF(!vecMask);
}

Value *anyOfMask =
createSimpleTargetReduction(B, &TTI, vecMask, RecurKind::Or);
createMaybeVPTargetReduction(B, TTI, vecMask, RecurKind::Or, VL);
anyOfMask->setName("any_of_mask");

if (isVector) {
Expand Down Expand Up @@ -2072,13 +2043,8 @@ ValuePacket Packetizer::Impl::packetizeGroupScan(
// Thus we essentially keep the original group scan, but change it to be an
// exclusive one.
auto *Reduction = Ops.front();
if (VL) {
Reduction = sanitizeVPReductionInput(B, Reduction, VL, Scan.Recurrence);
if (!Reduction) {
return results;
}
}
Reduction = createSimpleTargetReduction(B, &TTI, Reduction, Scan.Recurrence);
Reduction =
createMaybeVPTargetReduction(B, TTI, Reduction, Scan.Recurrence, VL);

// Now we defer to an *exclusive* scan over the group.
auto ExclScan = Scan;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,4 @@ if.end: ; preds = %if.then, %entry

; CHECK: define spir_kernel void @__vecz_nxv2_vp_foo(ptr addrspace(1) nocapture readonly %a, ptr addrspace(1) nocapture %out)
; CHECK: [[CMP:%.*]] = fcmp oeq <vscale x 2 x float> %{{.*}}, zeroinitializer
; CHECK: [[INS:%.*]] = insertelement <vscale x 2 x i32> poison, i32 [[VL:%.*]], {{(i32|i64)}} 0
; CHECK: [[SPLAT:%.*]] = shufflevector <vscale x 2 x i32> [[INS]], <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
; CHECK: [[IDX:%.*]] = call <vscale x 2 x i32> @llvm.experimental.stepvector.nxv2i32()
; CHECK: [[MASK:%.*]] = icmp ult <vscale x 2 x i32> [[IDX]], [[SPLAT]]
; CHECK: [[INP:%.*]] = select <vscale x 2 x i1> [[MASK]], <vscale x 2 x i1> [[CMP]], <vscale x 2 x i1> zeroinitializer
; CHECK: %{{.*}} = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[INP]])
; CHECK: %{{.*}} = call i1 @llvm.vp.reduce.or.nxv2i1(i1 false, <vscale x 2 x i1> [[CMP]], {{.*}}, i32 {{.*}})
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,7 @@ if.end:
ret void
; CHECK: define spir_kernel void @__vecz_nxv4_vp_mask_varying
; CHECK: [[CMP:%.*]] = icmp slt <vscale x 4 x i64> %{{.*}},
; CHECK: [[INS:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[VL:%.*]], {{(i32|i64)}} 0
; CHECK: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[INS]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK: [[IDX:%.*]] = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
; CHECK: [[MASK:%.*]] = icmp ult <vscale x 4 x i32> [[IDX]], [[SPLAT]]
; CHECK: [[INP:%.*]] = select <vscale x 4 x i1> [[MASK]], <vscale x 4 x i1> [[CMP]], <vscale x 4 x i1> zeroinitializer
; CHECK: [[RED:%.*]] = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> [[INP]])
; CHECK: [[RED:%.*]] = call i1 @llvm.vp.reduce.or.nxv4i1(i1 false, <vscale x 4 x i1> [[CMP]], {{.*}}, i32 {{.*}})
; CHECK: [[REINS:%.*]] = insertelement <4 x i1> poison, i1 [[RED]], {{(i32|i64)}} 0
; CHECK: [[RESPLAT:%.*]] = shufflevector <4 x i1> [[REINS]], <4 x i1> poison, <4 x i32> zeroinitializer
}
Expand Down
Loading

0 comments on commit 64d0ec5

Please sign in to comment.