Skip to content

Commit

Permalink
Merge pull request #65 from AinsleySnow/migrate
Browse files Browse the repository at this point in the history
[VP][LV] Migrate vector predication pass and intrinsics from llvm-bpevl.
  • Loading branch information
ChunyuLiao authored Mar 11, 2024
2 parents 2c203bb + f86030c commit e009baa
Show file tree
Hide file tree
Showing 31 changed files with 2,291 additions and 107 deletions.
11 changes: 11 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "llvm/ADT/SmallBitVector.h"
#include "llvm/IR/FMF.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
Expand Down Expand Up @@ -1714,6 +1715,9 @@ class TargetTransformInfo {
/// \return The maximum number of function arguments the target supports.
unsigned getMaxNumArgs() const;

/// Emit, through \p Builder, the computation of the vector length to use for
/// the requested vector length \p AVL at vectorization factor \p VF. The
/// work is delegated to the target implementation.
/// \return The value produced by the target implementation.
Value *computeVectorLength(IRBuilderBase &Builder, Value *AVL,
ElementCount VF) const;

/// @}

private:
Expand Down Expand Up @@ -2088,6 +2092,8 @@ class TargetTransformInfo::Concept {
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual bool hasArmWideBranch(bool Thumb) const = 0;
virtual unsigned getMaxNumArgs() const = 0;
// Pure-virtual counterpart of TargetTransformInfo::computeVectorLength.
// Model overrides it by forwarding to the implementation it wraps.
virtual Value *computeVectorLength(IRBuilderBase &Builder, Value *AVL,
ElementCount VF) const = 0;
};

template <typename T>
Expand Down Expand Up @@ -2815,6 +2821,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
// Forwards to the wrapped implementation's getMaxNumArgs().
unsigned getMaxNumArgs() const override {
return Impl.getMaxNumArgs();
}

// Forwards to the wrapped implementation's computeVectorLength() with the
// arguments unchanged.
Value *computeVectorLength(IRBuilderBase &Builder, Value *AVL,
ElementCount VF) const override {
return Impl.computeVectorLength(Builder, AVL, VF);
}
};

template <typename T>
Expand Down
16 changes: 16 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
Expand Down Expand Up @@ -908,6 +909,21 @@ class TargetTransformInfoImplBase {

// Default implementation: returns UINT_MAX.
unsigned getMaxNumArgs() const { return UINT_MAX; }

/// Default computation of the vector length for the requested length \p AVL
/// at vectorization factor \p VF, emitted through \p Builder.
///
/// - Fixed \p VF: returns the constant i32 lane count.
///   NOTE(review): \p AVL is not consulted on this path; confirm that callers
///   never pass an AVL smaller than a fixed VF.
/// - Scalable \p VF: returns umin(vscale * VF.getKnownMinValue(), AVL).
///
/// The minimum is taken in the wider of i32 and the type of \p AVL and only
/// then narrowed to i32. Narrowing \p AVL first would drop its high bits, so a
/// large AVL whose low 32 bits are small (or zero) would yield a result that
/// is too small (or zero).
///
/// \return An i32 value.
Value *computeVectorLength(IRBuilderBase &Builder, Value *AVL,
                           ElementCount VF) const {
  Type *EVLTy = Builder.getInt32Ty();
  if (!VF.isScalable())
    return ConstantInt::get(EVLTy, VF.getFixedValue());

  // Compare in a type that can hold every AVL value. For AVL types of at most
  // 32 bits this is i32 and the emitted IR matches the previous behaviour.
  Type *AVLTy = AVL->getType();
  Type *CmpTy = AVLTy->getScalarSizeInBits() > 32 ? AVLTy : EVLTy;

  Constant *EC = ConstantInt::get(CmpTy, VF.getKnownMinValue());
  Value *VLMax = Builder.CreateVScale(EC, "vlmax");
  Value *VL = Builder.CreateZExtOrTrunc(AVL, CmpTy, "vl");
  Value *Min = Builder.CreateIntrinsic(CmpTy, Intrinsic::umin, {VLMax, VL},
                                       nullptr, "evl");

  // No-op when the comparison already happened in i32.
  return Builder.CreateZExtOrTrunc(Min, EVLTy, "evl.trunc");
}

protected:
// Obtain the minimum required size to hold the value (without the sign)
// In case of a vector it returns the min required size for one element.
Expand Down
16 changes: 16 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
Expand Down Expand Up @@ -2558,6 +2559,21 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {

// Returns a fixed cost of 1.
InstructionCost getVectorSplitCost() { return 1; }

/// Computes the vector length for the requested length \p AVL at
/// vectorization factor \p VF, emitting any needed IR through \p Builder.
///
/// - Fixed \p VF: returns the constant i32 lane count.
///   NOTE(review): \p AVL is not consulted on this path; confirm that callers
///   never pass an AVL smaller than a fixed VF.
/// - Scalable \p VF: returns umin(vscale * VF.getKnownMinValue(), AVL).
///
/// The minimum is taken in the wider of i32 and the type of \p AVL and only
/// then narrowed to i32. Narrowing \p AVL first would drop its high bits, so a
/// large AVL whose low 32 bits are small (or zero) would yield a result that
/// is too small (or zero).
///
/// NOTE(review): TargetTransformInfoImplBase declares a default with the same
/// signature; consider sharing a single implementation.
///
/// \return An i32 value.
Value *computeVectorLength(IRBuilderBase &Builder, Value *AVL,
                           ElementCount VF) const {
  Type *EVLTy = Builder.getInt32Ty();
  if (!VF.isScalable())
    return ConstantInt::get(EVLTy, VF.getFixedValue());

  // Compare in a type that can hold every AVL value. For AVL types of at most
  // 32 bits this is i32 and the emitted IR matches the previous behaviour.
  Type *AVLTy = AVL->getType();
  Type *CmpTy = AVLTy->getScalarSizeInBits() > 32 ? AVLTy : EVLTy;

  Constant *EC = ConstantInt::get(CmpTy, VF.getKnownMinValue());
  Value *VLMax = Builder.CreateVScale(EC, "vlmax");
  Value *VL = Builder.CreateZExtOrTrunc(AVL, CmpTy, "vl");
  Value *Min = Builder.CreateIntrinsic(CmpTy, Intrinsic::umin, {VLMax, VL},
                                       nullptr, "evl");

  // No-op when the comparison already happened in i32.
  return Builder.CreateZExtOrTrunc(Min, EVLTy, "evl.trunc");
}

/// @}
};

Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/IntrinsicInst.h
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,9 @@ class VPIntrinsic : public IntrinsicInst {
/// The llvm.vp.* intrinsics for this instruction Opcode
static Intrinsic::ID getForOpcode(unsigned OC);

/// The llvm.vp.* intrinsic registered for the functional intrinsic \p IID,
/// or Intrinsic::not_intrinsic when no such registration exists.
static Intrinsic::ID getForIntrinsicID(Intrinsic::ID IID);

// Whether \p ID is a VP intrinsic ID.
static bool isVPIntrinsic(Intrinsic::ID);

Expand Down
20 changes: 17 additions & 3 deletions llvm/include/llvm/IR/VectorBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class VectorBuilder {
return RetType();
}

// Shared emission path used by the public createVectorInstructionFrom*
// entry points, which pass the already-resolved VP intrinsic ID as \p VPID.
Value *createVectorInstruction(Intrinsic::ID VPID, Type *ReturnTy,
ArrayRef<Value *> VecOpArray,
const Twine &Name = Twine());

public:
VectorBuilder(IRBuilderBase &Builder,
Behavior ErrorHandling = Behavior::ReportAndAbort)
Expand Down Expand Up @@ -89,9 +93,19 @@ class VectorBuilder {
// \p Opcode The functional instruction opcode of the emitted intrinsic.
// \p ReturnTy The return type of the operation.
// \p VecOpArray The operand list.
Value *createVectorInstruction(unsigned Opcode, Type *ReturnTy,
ArrayRef<Value *> VecOpArray,
const Twine &Name = Twine());
Value *createVectorInstructionFromOpcode(unsigned Opcode, Type *ReturnTy,
ArrayRef<Value *> VecOpArray,
const Twine &Name = Twine());

// Emit a VP intrinsic call that mimics a regular (non-VP) intrinsic.
// If no VP intrinsic is registered for \p IID, the failure is handled
// according to this builder's error-handling Behavior.
// \p IID The functional intrinsic ID of the emitted VP intrinsic.
// \p ReturnTy The return type of the operation.
// \p VecOpArray The operand list.
// \p Name Passed through to the instruction-emission helper.
Value *createVectorInstructionFromIntrinsicID(Intrinsic::ID IID,
Type *ReturnTy,
ArrayRef<Value *> VecOpArray,
const Twine &Name = Twine());
};

} // namespace llvm
Expand Down
55 changes: 55 additions & 0 deletions llvm/include/llvm/Transforms/Vectorize/VectorPredication.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef LLVM_TRANSFORMS_VECTORPREDICATION_H
#define LLVM_TRANSFORMS_VECTORPREDICATION_H

#include "llvm/ADT/MapVector.h"
#include "llvm/IR/PassManager.h"

namespace llvm {

using InstToMaskEVLMap = DenseMap<Instruction *, std::pair<Value *, Value *>>;

// Bookkeeping gathered for a single basic block while the vector predication
// pass processes it.
struct BlockData {
// All vector-predicated memory-writing operations found in the basic block.
// If this is still empty after phase 1, the following phases can skip the
// basic block.
SmallVector<Instruction *> MemoryWritingVPInstructions;

// Every instruction of the basic block, in the order in which it was found,
// mapped to the set of its users. PHIs and terminators are skipped.
MapVector<Instruction *, SmallPtrSet<Instruction *, 4>> TopologicalGraph;

// Maps each full-length vector operation that is eligible to become a
// vector-predicated one to the (mask, evl) pair of its first
// vector-predicated memory-writing user.
InstToMaskEVLMap VecOpsToTransform;

// Instructions listed in the reverse of the order in which the basic block
// has to be laid out once the new vector-predicated instructions exist.
SmallVector<Instruction *> NewBBReverseOrder;

BlockData() = default;
};

/// Function pass (PassInfoMixin-based) that carries the vector predication
/// transformation. The per-block helpers below receive the basic block being
/// processed together with its BlockData record.
class VectorPredicationPass : public PassInfoMixin<VectorPredicationPass> {
public:
  /// Name reported for this pass.
  static StringRef name() { return "VectorPredicationPass"; }

  /// Pass entry point, invoked on function \p F.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);

private:
  // Per-basic-block steps.
  // NOTE(review): the names suggest an analysis step, candidate discovery,
  // mask/EVL user bookkeeping, schedule construction and emission, and the
  // final rewrite; confirm the exact order against the definitions.
  void analyseBasicBlock(BasicBlock &BB, BlockData &BBInfo);
  void findCandidateVectorOperations(BasicBlock &BB, BlockData &BBInfo);
  void addNewUsersToMasksAndEVLs(BasicBlock &BB, BlockData &BBInfo);
  void buildNewBasicBlockSchedule(BasicBlock &BB, BlockData &BBInfo);
  void emitNewBasicBlockSchedule(BasicBlock &BB, BlockData &BBInfo);
  void transformCandidateVectorOperations(BasicBlock &BB, BlockData &BBInfo);

  // Cleanup for the instructions recorded in OldInstructionsToRemove.
  void removeOldInstructions();

  // Instructions that get replaced by the new VP operations and should be
  // removed afterwards, if possible, each associated with a Value.
  DenseMap<Instruction *, Value *> OldInstructionsToRemove;
};

} // namespace llvm

#endif // LLVM_TRANSFORMS_VECTORPREDICATION_H
6 changes: 6 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,12 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
}

// Delegates to the target implementation held in TTIImpl, passing the
// arguments through unchanged.
Value *TargetTransformInfo::computeVectorLength(IRBuilderBase &Builder,
Value *AVL,
ElementCount VF) const {
return TTIImpl->computeVectorLength(Builder, AVL, VF);
}

TargetTransformInfo::Concept::~Concept() = default;

TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/IR/IntrinsicInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,19 @@ Intrinsic::ID VPIntrinsic::getForOpcode(unsigned IROPC) {
return Intrinsic::not_intrinsic;
}

// Maps the functional intrinsic \p IID to the VP intrinsic registered for it
// in VPIntrinsics.def, or returns Intrinsic::not_intrinsic if there is none.
Intrinsic::ID VPIntrinsic::getForIntrinsicID(Intrinsic::ID IID) {
switch (IID) {
default:
break;

// The switch body is generated by expanding VPIntrinsics.def with the three
// macros below:
//  - BEGIN_REGISTER_VP_INTRINSIC expands to `break;`, which ends whatever
//    precedes it in the switch.
//  - VP_PROPERTY_FUNCTIONAL_INTRINSIC expands to a `case` label for the
//    functional intrinsic INTR.
//  - END_REGISTER_VP_INTRINSIC expands to `return Intrinsic::VPID;`, which a
//    preceding case label reaches by falling through.
// NOTE(review): this relies on every functional-intrinsic property in the
// .def file sitting between the BEGIN/END pair of its own VP intrinsic.
#define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...) break;
#define VP_PROPERTY_FUNCTIONAL_INTRINSIC(INTR) case Intrinsic::INTR:
#define END_REGISTER_VP_INTRINSIC(VPID) return Intrinsic::VPID;
#include "llvm/IR/VPIntrinsics.def"
}
// Reached via one of the `break;` statements: nothing matched IID.
return Intrinsic::not_intrinsic;
}

bool VPIntrinsic::canIgnoreVectorLengthParam() const {
using namespace PatternMatch;

Expand Down
23 changes: 20 additions & 3 deletions llvm/lib/IR/VectorBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,30 @@ Value &VectorBuilder::requestEVL() {
return *ConstantInt::get(IntTy, StaticVectorLength.getFixedValue());
}

Value *VectorBuilder::createVectorInstruction(unsigned Opcode, Type *ReturnTy,
ArrayRef<Value *> InstOpArray,
const Twine &Name) {
// Resolves the VP intrinsic that corresponds to the instruction opcode
// \p Opcode and emits it via createVectorInstruction. When no VP intrinsic
// exists for the opcode, the failure goes through returnWithError.
Value *VectorBuilder::createVectorInstructionFromOpcode(
    unsigned Opcode, Type *ReturnTy, ArrayRef<Value *> InstOpArray,
    const Twine &Name) {
  const Intrinsic::ID ResolvedID = VPIntrinsic::getForOpcode(Opcode);
  if (ResolvedID != Intrinsic::not_intrinsic)
    return createVectorInstruction(ResolvedID, ReturnTy, InstOpArray, Name);

  return returnWithError<Value *>("No VPIntrinsic for this opcode");
}

// Resolves the VP intrinsic that corresponds to the functional intrinsic
// \p IID and emits it via createVectorInstruction. When no VP intrinsic
// exists for \p IID, the failure goes through returnWithError.
Value *VectorBuilder::createVectorInstructionFromIntrinsicID(
    Intrinsic::ID IID, Type *ReturnTy, ArrayRef<Value *> InstOpArray,
    const Twine &Name) {
  const Intrinsic::ID ResolvedID = VPIntrinsic::getForIntrinsicID(IID);
  if (ResolvedID != Intrinsic::not_intrinsic)
    return createVectorInstruction(ResolvedID, ReturnTy, InstOpArray, Name);

  return returnWithError<Value *>("No VPIntrinsic for this Intrinsic");
}

Value *VectorBuilder::createVectorInstruction(Intrinsic::ID VPID,
Type *ReturnTy,
ArrayRef<Value *> InstOpArray,
const Twine &Name) {
auto MaskPosOpt = VPIntrinsic::getMaskParamPos(VPID);
auto VLenPosOpt = VPIntrinsic::getVectorLengthParamPos(VPID);
size_t NumInstParams = InstOpArray.size();
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@
#include "llvm/Transforms/Vectorize/LoopVectorize.h"
#include "llvm/Transforms/Vectorize/SLPVectorizer.h"
#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/Transforms/Vectorize/VectorPredication.h"
#include <optional>

using namespace llvm;
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Passes/PassBuilderPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@
#include "llvm/Transforms/Vectorize/LoopVectorize.h"
#include "llvm/Transforms/Vectorize/SLPVectorizer.h"
#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/Transforms/Vectorize/VectorPredication.h"

using namespace llvm;

Expand Down Expand Up @@ -285,6 +286,11 @@ cl::opt<bool> EnableMemProfContextDisambiguation(
extern cl::opt<bool> EnableInferAlignmentPass;
} // namespace llvm

// Hidden command-line flag, off by default. When set, addVectorPasses adds
// VectorPredicationPass to the function pass pipeline.
static cl::opt<bool>
EnableVectorPredication("enable-vector-predication", cl::init(false),
cl::Hidden,
cl::desc("Enable VectorPredicationPass."));

PipelineTuningOptions::PipelineTuningOptions() {
LoopInterleaving = true;
LoopVectorization = true;
Expand Down Expand Up @@ -1297,6 +1303,10 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level,
/*AllowSpeculation=*/true),
/*UseMemorySSA=*/true, /*UseBlockFrequencyInfo=*/false));

// Try to vector predicate vectorized functions.
if (EnableVectorPredication)
FPM.addPass(VectorPredicationPass());

// Now that we've vectorized and unrolled loops, we may have more refined
// alignment information, try to re-derive it here.
FPM.addPass(AlignmentFromAssumptionsPass());
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ FUNCTION_PASS("tsan", ThreadSanitizerPass())
FUNCTION_PASS("typepromotion", TypePromotionPass(TM))
FUNCTION_PASS("unify-loop-exits", UnifyLoopExitsPass())
FUNCTION_PASS("vector-combine", VectorCombinePass())
FUNCTION_PASS("vector-predication", VectorPredicationPass())
FUNCTION_PASS("verify", VerifierPass())
FUNCTION_PASS("verify<domtree>", DominatorTreeVerifierPass())
FUNCTION_PASS("verify<loops>", LoopVerifierPass())
Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "llvm/CodeGen/CostTable.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include <cmath>
#include <optional>
using namespace llvm;
Expand Down Expand Up @@ -1848,3 +1849,36 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
C2.NumIVMuls, C2.NumBaseAdds,
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
}

/// RISC-V implementation: obtains the vector length by emitting a call to the
/// riscv_vsetvli intrinsic for the requested length \p AVL, using the
/// (SEW, LMUL) operands associated with the minimum lane count of \p VF.
///
/// \p AVL must have integer type and the known-minimum value of \p VF must be
/// one of 1, 2, 4, 8, 16, 32 or 64; both conditions are asserted.
///
/// For a fixed \p VF the result is additionally clamped to the lane count.
///
/// \return An i32 value.
Value *RISCVTTIImpl::computeVectorLength(IRBuilderBase &Builder, Value *AVL,
                                         ElementCount VF) const {
  // Maps a VF to a (SEW, LMUL) pair.
  // NOTE: we assume ELEN = 64.
  // NOTE(review): the pair values look like the encoded vsew/vlmul operands
  // expected by vsetvli rather than raw widths/multipliers - confirm.
  // The table is constant, so build it once instead of on every call.
  static const std::map<unsigned int, std::pair<unsigned int, unsigned int>>
      VFToSEWLMUL = {{1, {3, 0}},  {2, {3, 1}},  {4, {3, 2}}, {8, {3, 3}},
                     {16, {2, 3}}, {32, {1, 3}}, {64, {0, 3}}};

  assert(AVL->getType()->isIntegerTy() &&
         "Requested vector length should be an integer.");
  const unsigned MinVF = VF.getKnownMinValue();
  assert(VFToSEWLMUL.count(MinVF) != 0 &&
         "Unsupported VF: no (SEW, LMUL) mapping.");
  const std::pair<unsigned int, unsigned int> &SEWLMUL = VFToSEWLMUL.at(MinVF);

  // NOTE(review): the intrinsic is instantiated on i64 unconditionally, which
  // presumably assumes a 64-bit XLEN - confirm for 32-bit targets.
  Value *AVLArg = Builder.CreateZExtOrTrunc(AVL, Builder.getInt64Ty());
  Constant *SEWArg = ConstantInt::get(Builder.getInt64Ty(), SEWLMUL.first);
  Constant *LMULArg = ConstantInt::get(Builder.getInt64Ty(), SEWLMUL.second);
  Value *EVLRes =
      Builder.CreateIntrinsic(Intrinsic::riscv_vsetvli, {AVLArg->getType()},
                              {AVLArg, SEWArg, LMULArg}, nullptr, "vl");

  // NOTE: evl type is required to be i32.
  Value *EVL = Builder.CreateZExtOrTrunc(EVLRes, Builder.getInt32Ty());
  if (!VF.isScalable()) {
    // A fixed-width vector has exactly VF lanes, so never report more.
    EVL = Builder.CreateBinaryIntrinsic(
        Intrinsic::umin,
        ConstantInt::get(Builder.getInt32Ty(), VF.getFixedValue()), EVL);
  }
  return EVL;
}
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
// Always answers true for this target.
bool shouldFoldTerminatingConditionAfterLSR() const {
return true;
}

/// RISC-V specific computation of the vector length for the requested length
/// \p AVL at vectorization factor \p VF; defined out of line in
/// RISCVTargetTransformInfo.cpp.
Value *computeVectorLength(IRBuilderBase &Builder, Value *AVL,
ElementCount VF) const;
};

} // end namespace llvm
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_llvm_component_library(LLVMVectorize
SLPVectorizer.cpp
Vectorize.cpp
VectorCombine.cpp
VectorPredication.cpp
VPlan.cpp
VPlanAnalysis.cpp
VPlanHCFGBuilder.cpp
Expand Down
Loading

0 comments on commit e009baa

Please sign in to comment.