Skip to content

Commit 47a9b69

Browse files
[LV][VPlan] Add initial support for CSA vectorization
This patch adds initial support for CSA vectorization in LLVM. This new class can be characterized by vectorization of assignment to a scalar in a loop, such that the assignment is conditional from the perspective of its use. An assignment is conditional in a loop if a value may or may not be assigned in the loop body. For example: ``` int t = init_val; for (int i = 0; i < N; i++) { if (cond[i]) t = a[i]; } s = t; // use t ``` Using pseudo-LLVM code this can be vectorized as ``` vector.ph: ... %t = %init_val %init.mask = <all-false-vec> %init.data = <poison-vec> ; uninitialized vector.body: ... %mask.phi = phi [%init.mask, %vector.ph], [%new.mask, %vector.body] %data.phi = phi [%init.data, %vector.ph], [%new.data, %vector.body] %cond.vec = <widened-cmp> ... %a.vec = <widened-load> %a, %i %b = <any-lane-active> %cond.vec %new.mask = select %b, %cond.vec, %mask.phi %new.data = select %b, %a.vec, %data.phi ... middle.block: %s = <extract-last-active-lane> %new.mask, %new.data ``` On each iteration, we track whether any lane in the widened condition was active, and if it was, take the current mask and data as the new mask and data vector. Then at the end of the loop, the scalar can be extracted only once. This transformation works the same way for integer, pointer, and floating point conditional assignment, since the transformation does not require inspection of the data being assigned. In the vectorization of a CSA, we will be introducing recipes into the vector preheader, the vector body, and the middle block. Recipes that are introduced into the preheader and middle block are executed only one time, and recipes that are in the vector body will be possibly executed multiple times. The more times that the vector body is executed, the less of an impact the preheader and middle block cost have on the overall cost of a CSA. A detailed explanation of the concept can be found [here](https://discourse.llvm.org/t/vectorization-of-conditional-scalar-assignment-csa/80964). 
This patch is further tested in llvm/llvm-test-suite#155. This patch contains only the non-EVL related code. This is based on the larger patch of llvm#106560, which contained both EVL and non-EVL related parts.
1 parent 9ab16d4 commit 47a9b69

18 files changed

+3751
-20
lines changed

Diff for: llvm/include/llvm/Analysis/IVDescriptors.h

+57-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file "describes" induction and recurrence variables.
9+
// This file "describes" induction, recurrence, and conditional scalar
10+
// assignment (CSA) variables.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

@@ -423,6 +424,61 @@ class InductionDescriptor {
423424
SmallVector<Instruction *, 2> RedundantCasts;
424425
};
425426

427+
/// A Conditional Scalar Assignment (CSA) is an assignment from an initial
428+
/// scalar that may or may not occur.
429+
class CSADescriptor {
430+
/// If the conditional assignment occurs inside a loop, then Phi chooses
431+
/// the value of the assignment from the entry block or the loop body block.
432+
PHINode *Phi = nullptr;
433+
434+
/// The initial value of the CSA. If the condition guarding the assignment is
435+
/// not met, then the assignment retains this value.
436+
Value *InitScalar = nullptr;
437+
438+
/// The Instruction that is conditionally assigned to inside the loop.
439+
Instruction *Assignment = nullptr;
440+
441+
/// Create a CSA Descriptor that models a valid CSA with its members
442+
/// initialized correctly.
443+
CSADescriptor(PHINode *Phi, Instruction *Assignment, Value *InitScalar)
444+
: Phi(Phi), InitScalar(InitScalar), Assignment(Assignment) {}
445+
446+
public:
447+
/// Create a CSA Descriptor that models an invalid CSA.
448+
CSADescriptor() = default;
449+
450+
/// If Phi is the root of a CSA, set CSADesc as the CSA rooted by
451+
/// Phi. Otherwise, return false, leaving CSADesc unmodified.
452+
static bool isCSAPhi(PHINode *Phi, Loop *TheLoop, CSADescriptor &CSADesc);
453+
454+
operator bool() const { return isValid(); }
455+
456+
/// Returns whether SI is the Assignment in CSA
457+
static bool isCSASelect(CSADescriptor Desc, SelectInst *SI) {
458+
return Desc.getAssignment() == SI;
459+
}
460+
461+
/// Return whether this CSADescriptor models a valid CSA.
462+
bool isValid() const { return Phi && InitScalar && Assignment; }
463+
464+
/// Return the PHI that roots this CSA.
465+
PHINode *getPhi() const { return Phi; }
466+
467+
/// Return the initial value of the CSA. This is the value if the conditional
468+
/// assignment does not occur.
469+
Value *getInitScalar() const { return InitScalar; }
470+
471+
/// The Instruction that is used after the loop
472+
Instruction *getAssignment() const { return Assignment; }
473+
474+
/// Return the condition that this CSA is conditional upon.
475+
Value *getCond() const {
476+
if (auto *SI = dyn_cast_or_null<SelectInst>(Assignment))
477+
return SI->getCondition();
478+
return nullptr;
479+
}
480+
};
481+
426482
} // end namespace llvm
427483

428484
#endif // LLVM_ANALYSIS_IVDESCRIPTORS_H

Diff for: llvm/include/llvm/Analysis/TargetTransformInfo.h

+9
Original file line numberDiff line numberDiff line change
@@ -1852,6 +1852,10 @@ class TargetTransformInfo {
18521852
: EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
18531853
};
18541854

1855+
/// \returns true if the loop vectorizer should vectorize conditional
1856+
/// scalar assignments for the target.
1857+
bool enableCSAVectorization() const;
1858+
18551859
/// \returns How the target needs this vector-predicated operation to be
18561860
/// transformed.
18571861
VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
@@ -2305,6 +2309,7 @@ class TargetTransformInfo::Concept {
23052309
SmallVectorImpl<Use *> &OpsToSink) const = 0;
23062310

23072311
virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
2312+
virtual bool enableCSAVectorization() const = 0;
23082313
virtual VPLegalization
23092314
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
23102315
virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -3130,6 +3135,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
31303135
return Impl.isVectorShiftByScalarCheap(Ty);
31313136
}
31323137

3138+
bool enableCSAVectorization() const override {
3139+
return Impl.enableCSAVectorization();
3140+
}
3141+
31333142
VPLegalization
31343143
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
31353144
return Impl.getVPLegalizationStrategy(PI);

Diff for: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+2
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,8 @@ class TargetTransformInfoImplBase {
10301030

10311031
bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
10321032

1033+
bool enableCSAVectorization() const { return false; }
1034+
10331035
TargetTransformInfo::VPLegalization
10341036
getVPLegalizationStrategy(const VPIntrinsic &PI) const {
10351037
return TargetTransformInfo::VPLegalization(

Diff for: llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

+17
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,10 @@ class LoopVectorizationLegality {
269269
/// induction descriptor.
270270
using InductionList = MapVector<PHINode *, InductionDescriptor>;
271271

272+
/// CSAList contains the CSA descriptors for all the CSAs that were found
273+
/// in the loop, rooted by their phis.
274+
using CSAList = MapVector<PHINode *, CSADescriptor>;
275+
272276
/// RecurrenceSet contains the phi nodes that are recurrences other than
273277
/// inductions and reductions.
274278
using RecurrenceSet = SmallPtrSet<const PHINode *, 8>;
@@ -321,6 +325,12 @@ class LoopVectorizationLegality {
321325
/// Returns True if V is a Phi node of an induction variable in this loop.
322326
bool isInductionPhi(const Value *V) const;
323327

328+
/// Returns the CSAs found in the loop.
329+
const CSAList &getCSAs() const { return CSAs; }
330+
331+
/// Returns true if Phi is the root of a CSA in the loop.
332+
bool isCSAPhi(PHINode *Phi) const { return CSAs.count(Phi) != 0; }
333+
324334
/// Returns a pointer to the induction descriptor, if \p Phi is an integer or
325335
/// floating point induction.
326336
const InductionDescriptor *getIntOrFpInductionDescriptor(PHINode *Phi) const;
@@ -550,6 +560,10 @@ class LoopVectorizationLegality {
550560
void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID,
551561
SmallPtrSetImpl<Value *> &AllowedExit);
552562

563+
/// Updates the vectorization state by adding \p Phi to the CSA list.
564+
void addCSAPhi(PHINode *Phi, const CSADescriptor &CSADesc,
565+
SmallPtrSetImpl<Value *> &AllowedExit);
566+
553567
/// The loop that we evaluate.
554568
Loop *TheLoop;
555569

@@ -594,6 +608,9 @@ class LoopVectorizationLegality {
594608
/// variables can be pointers.
595609
InductionList Inductions;
596610

611+
/// Holds the conditional scalar assignments
612+
CSAList CSAs;
613+
597614
/// Holds all the casts that participate in the update chain of the induction
598615
/// variables, and that have been proven to be redundant (possibly under a
599616
/// runtime guard). These casts can be ignored when creating the vectorized

Diff for: llvm/lib/Analysis/IVDescriptors.cpp

+57-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file "describes" induction and recurrence variables.
9+
// This file "describes" induction, recurrence, and conditional scalar
10+
// assignment (CSA) variables.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

@@ -1570,3 +1571,58 @@ bool InductionDescriptor::isInductionPHI(
15701571
D = InductionDescriptor(StartValue, IK_PtrInduction, Step);
15711572
return true;
15721573
}
1574+
1575+
/// Return CSADescriptor that describes a CSA that matches one of these
1576+
/// patterns:
1577+
/// phi loop_inv, (select cmp, value, phi)
1578+
/// phi loop_inv, (select cmp, phi, value)
1579+
/// phi (select cmp, value, phi), loop_inv
1580+
/// phi (select cmp, phi, value), loop_inv
1581+
/// If the CSA does not match any of these patterns, return a CSADescriptor
1582+
/// that describes an InvalidCSA.
1583+
bool CSADescriptor::isCSAPhi(PHINode *Phi, Loop *TheLoop, CSADescriptor &CSA) {
1584+
1585+
// Must be a scalar.
1586+
Type *Type = Phi->getType();
1587+
if (!Type->isIntegerTy() && !Type->isFloatingPointTy() &&
1588+
!Type->isPointerTy())
1589+
return false;
1590+
1591+
// Match phi loop_inv, (select cmp, value, phi)
1592+
// or phi loop_inv, (select cmp, phi, value)
1593+
// or phi (select cmp, value, phi), loop_inv
1594+
// or phi (select cmp, phi, value), loop_inv
1595+
if (Phi->getNumIncomingValues() != 2)
1596+
return false;
1597+
auto SelectInstIt = find_if(Phi->incoming_values(), [&Phi](const Use &U) {
1598+
return match(U.get(), m_Select(m_Value(), m_Specific(Phi), m_Value())) ||
1599+
match(U.get(), m_Select(m_Value(), m_Value(), m_Specific(Phi)));
1600+
});
1601+
if (SelectInstIt == Phi->incoming_values().end())
1602+
return false;
1603+
auto LoopInvIt = find_if(Phi->incoming_values(), [&](Use &U) {
1604+
return U.get() != *SelectInstIt && TheLoop->isLoopInvariant(U.get());
1605+
});
1606+
if (LoopInvIt == Phi->incoming_values().end())
1607+
return false;
1608+
1609+
// Phi or Sel must be used only outside the loop,
1610+
// excluding the cases where Phi uses Sel or Sel uses Phi
1611+
auto IsOnlyUsedOutsideLoop = [&](Value *V, Value *Ignore) {
1612+
return all_of(V->users(), [Ignore, TheLoop](User *U) {
1613+
if (U == Ignore)
1614+
return true;
1615+
if (auto *I = dyn_cast<Instruction>(U))
1616+
return !TheLoop->contains(I);
1617+
return true;
1618+
});
1619+
};
1620+
Instruction *Select = cast<SelectInst>(SelectInstIt->get());
1621+
Value *LoopInv = LoopInvIt->get();
1622+
if (!IsOnlyUsedOutsideLoop(Phi, Select) ||
1623+
!IsOnlyUsedOutsideLoop(Select, Phi))
1624+
return false;
1625+
1626+
CSA = CSADescriptor(Phi, Select, LoopInv);
1627+
return true;
1628+
}

Diff for: llvm/lib/Analysis/TargetTransformInfo.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,10 @@ bool TargetTransformInfo::preferEpilogueVectorization() const {
13731373
return TTIImpl->preferEpilogueVectorization();
13741374
}
13751375

1376+
bool TargetTransformInfo::enableCSAVectorization() const {
1377+
return TTIImpl->enableCSAVectorization();
1378+
}
1379+
13761380
TargetTransformInfo::VPLegalization
13771381
TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
13781382
return TTIImpl->getVPLegalizationStrategy(VPI);

Diff for: llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -2361,6 +2361,11 @@ bool RISCVTTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
23612361
return true;
23622362
}
23632363

2364+
bool RISCVTTIImpl::enableCSAVectorization() const {
2365+
return ST->hasVInstructions() &&
2366+
ST->getProcFamily() == RISCVSubtarget::SiFive7;
2367+
}
2368+
23642369
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
23652370
auto *VTy = dyn_cast<VectorType>(DataTy);
23662371
if (!VTy || VTy->isScalableTy())

Diff for: llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

+4
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
306306
return TLI->isVScaleKnownToBeAPowerOfTwo();
307307
}
308308

309+
/// \returns true if the loop vectorizer should vectorize conditional
310+
/// scalar assignments for the target.
311+
bool enableCSAVectorization() const;
312+
309313
/// \returns How the target needs this vector-predicated operation to be
310314
/// transformed.
311315
TargetTransformInfo::VPLegalization

Diff for: llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

+31-4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ static cl::opt<bool> EnableHistogramVectorization(
8383
"enable-histogram-loop-vectorization", cl::init(false), cl::Hidden,
8484
cl::desc("Enables autovectorization of some loops containing histograms"));
8585

86+
static cl::opt<bool>
87+
EnableCSA("enable-csa-vectorization", cl::init(false), cl::Hidden,
88+
cl::desc("Control whether CSA loop vectorization is enabled"));
89+
8690
/// Maximum vectorization interleave count.
8791
static const unsigned MaxInterleaveFactor = 16;
8892

@@ -749,6 +753,15 @@ bool LoopVectorizationLegality::setupOuterLoopInductions() {
749753
return llvm::all_of(Header->phis(), IsSupportedPhi);
750754
}
751755

756+
void LoopVectorizationLegality::addCSAPhi(
757+
PHINode *Phi, const CSADescriptor &CSADesc,
758+
SmallPtrSetImpl<Value *> &AllowedExit) {
759+
assert(CSADesc.isValid() && "Expected Valid CSADescriptor");
760+
LLVM_DEBUG(dbgs() << "LV: found legal CSA opportunity" << *Phi << "\n");
761+
AllowedExit.insert(Phi);
762+
CSAs.insert({Phi, CSADesc});
763+
}
764+
752765
/// Checks if a function is scalarizable according to the TLI, in
753766
/// the sense that it should be vectorized and then expanded in
754767
/// multiple scalar calls. This is represented in the
@@ -866,14 +879,24 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
866879
continue;
867880
}
868881

869-
// As a last resort, coerce the PHI to a AddRec expression
870-
// and re-try classifying it a an induction PHI.
882+
// Try to coerce the PHI to an AddRec expression and re-try classifying
883+
// it as an induction PHI.
871884
if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true) &&
872885
!IsDisallowedStridedPointerInduction(ID)) {
873886
addInductionPhi(Phi, ID, AllowedExit);
874887
continue;
875888
}
876889

890+
// Check if the PHI can be classified as a CSA PHI.
891+
if (EnableCSA || (TTI->enableCSAVectorization() &&
892+
EnableCSA.getNumOccurrences() == 0)) {
893+
CSADescriptor CSADesc;
894+
if (CSADescriptor::isCSAPhi(Phi, TheLoop, CSADesc)) {
895+
addCSAPhi(Phi, CSADesc, AllowedExit);
896+
continue;
897+
}
898+
}
899+
877900
reportVectorizationFailure("Found an unidentified PHI",
878901
"value that could not be identified as "
879902
"reduction is used outside the loop",
@@ -1844,11 +1867,15 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
18441867
for (const auto &Reduction : getReductionVars())
18451868
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
18461869

1870+
SmallPtrSet<const Value *, 8> CSALiveOuts;
1871+
for (const auto &CSA : getCSAs())
1872+
CSALiveOuts.insert(CSA.second.getAssignment());
1873+
18471874
// TODO: handle non-reduction outside users when tail is folded by masking.
18481875
for (auto *AE : AllowedExit) {
18491876
// Check that all users of allowed exit values are inside the loop or
1850-
// are the live-out of a reduction.
1851-
if (ReductionLiveOuts.count(AE))
1877+
// are the live-out of a reduction or CSA.
1878+
if (ReductionLiveOuts.count(AE) || CSALiveOuts.count(AE))
18521879
continue;
18531880
for (User *U : AE->users()) {
18541881
Instruction *UI = cast<Instruction>(U);

Diff for: llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h

+18-2
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ class VPBuilder {
174174
new VPInstruction(Opcode, Operands, WrapFlags, DL, Name));
175175
}
176176

177-
VPValue *createNot(VPValue *Operand, DebugLoc DL = {},
178-
const Twine &Name = "") {
177+
VPInstruction *createNot(VPValue *Operand, DebugLoc DL = {},
178+
const Twine &Name = "") {
179179
return createInstruction(VPInstruction::Not, {Operand}, DL, Name);
180180
}
181181

@@ -257,6 +257,22 @@ class VPBuilder {
257257
FPBinOp ? FPBinOp->getFastMathFlags() : FastMathFlags()));
258258
}
259259

260+
VPInstruction *createCSAMaskPhi(VPValue *InitMask, DebugLoc DL,
261+
const Twine &Name) {
262+
return createInstruction(VPInstruction::CSAMaskPhi, {InitMask}, DL, Name);
263+
}
264+
265+
VPInstruction *createAnyOf(VPValue *Cond, DebugLoc DL, const Twine &Name) {
266+
return createInstruction(VPInstruction::AnyOf, {Cond}, DL, Name);
267+
}
268+
269+
VPInstruction *createCSAMaskSel(VPValue *Cond, VPValue *MaskPhi,
270+
VPValue *AnyOf, DebugLoc DL,
271+
const Twine &Name) {
272+
return createInstruction(VPInstruction::CSAMaskSel, {Cond, MaskPhi, AnyOf},
273+
DL, Name);
274+
}
275+
260276
//===--------------------------------------------------------------------===//
261277
// RAII helpers.
262278
//===--------------------------------------------------------------------===//

0 commit comments

Comments
 (0)