Skip to content

Commit

Permalink
Merge pull request #266 from frasercrmck/vecz-masked-cmpxhg
Browse files Browse the repository at this point in the history
[vecz] Add support for masking cmpxchg instructions
  • Loading branch information
frasercrmck authored Dec 20, 2023
2 parents 1e1b744 + db755a2 commit d77f61e
Show file tree
Hide file tree
Showing 12 changed files with 845 additions and 350 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ void UniformValueResult::findVectorLeaves(
Op->isMaskedScatterGatherMemOp())) {
IsCallLeaf = true;
}
} else if (Ctx.isMaskedAtomicFunction(*CI->getCalledFunction())) {
IsCallLeaf = true;
}
if (IsCallLeaf) {
Leaves.push_back(CI);
Expand Down
46 changes: 34 additions & 12 deletions modules/compiler/vecz/source/include/vectorization_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,36 +153,59 @@ class VectorizationContext {
/// @return The masked version of the function
llvm::Function *getOrCreateMaskedFunction(llvm::CallInst *CI);

struct MaskedAtomicRMW {
/// @brief Represents either an atomicrmw or cmpxchg operation.
///
/// Most fields are shared, with the exception of CmpXchgFailureOrdering and
/// IsWeak, which are only to be set for cmpxchg, and BinOp, which is only to
/// be set to a valid value for atomicrmw.
struct MaskedAtomic {
llvm::Type *PointerTy;
llvm::Type *ValTy;
/// @brief Must be set to BAD_BINOP for cmpxchg instructions
llvm::AtomicRMWInst::BinOp BinOp;
llvm::Align Align;
bool IsVolatile = false;
llvm::SyncScope::ID SyncScope;
llvm::AtomicOrdering Ordering;
/// @brief Must be set for cmpxchg instructions
std::optional<llvm::AtomicOrdering> CmpXchgFailureOrdering = std::nullopt;
/// @brief Must only be set for cmpxchg instructions
bool IsWeak = false;
// Vectorization info
llvm::ElementCount VF;
bool IsVectorPredicated = false;

/// @brief Returns true if this MaskedAtomic represents a cmpxchg operation.
bool isCmpXchg() const {
if (CmpXchgFailureOrdering.has_value()) {
// 'binop' only applies to atomicrmw
assert(BinOp == llvm::AtomicRMWInst::BAD_BINOP &&
"Invalid MaskedAtomic state");
return true;
}
// 'weak' only applies to cmpxchg
assert(!IsWeak && "Invalid MaskedAtomic state");
return false;
}
};

/// @brief Check if the given function is a masked version of an atomic RMW
/// operation.
/// @brief Check if the given function is a masked version of an atomicrmw or
/// cmpxchg operation.
///
/// @param[in] F The function to check
/// @return A MaskedAtomicRMW instance detailing the atomic operation if the
/// function is a masked atomic RMW, or std::nullopt otherwise
std::optional<MaskedAtomicRMW> isMaskedAtomicRMWFunction(
/// @return A MaskedAtomic instance detailing the atomic operation if the
/// function is a masked atomic, or std::nullopt otherwise
std::optional<MaskedAtomic> isMaskedAtomicFunction(
const llvm::Function &F) const;
/// @brief Get (if it exists already) or create the function representing the
/// masked version of an atomic RMW operation.
/// masked version of an atomicrmw/cmpxchg operation.
///
/// @param[in] I Atomic to be masked
/// @param[in] Choices Choices to mangle into the function name
/// @param[in] VF The vectorization factor of the atomic operation
/// @return The masked version of the function
llvm::Function *getOrCreateMaskedAtomicRMWFunction(
MaskedAtomicRMW &I, const VectorizationChoices &Choices,
llvm::Function *getOrCreateMaskedAtomicFunction(
MaskedAtomic &I, const VectorizationChoices &Choices,
llvm::ElementCount VF);

/// @brief Create a VectorizationUnit to use to vectorize the given scalar
Expand Down Expand Up @@ -296,10 +319,9 @@ class VectorizationContext {
/// @brief Emit the body for a masked atomic builtin
///
/// @param[in] F The empty (declaration only) function to emit the body in
/// @param[in] MA The MaskedAtomicRMW information
/// @param[in] MA The MaskedAtomic information
/// @returns true on success, false otherwise
bool emitMaskedAtomicRMWBody(llvm::Function &F,
const MaskedAtomicRMW &MA) const;
bool emitMaskedAtomicBody(llvm::Function &F, const MaskedAtomic &MA) const;

/// @brief Helper for non-vectorization tasks.
TargetInfo &VTI;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,14 @@ class ControlFlowConversionState::Impl : public ControlFlowConversionState {
/// @return true if it is valid to mask this call, false otherwise
bool applyMaskToCall(CallInst *CI, Value *mask, DeletionMap &toDelete);

/// @brief Attempt to apply a mask to an AtomicRMW instruction via a builtin
/// @brief Attempt to apply a mask to an atomic instruction via a builtin
/// call.
///
/// @param[in] atomicI The atomic instruction to apply the mask to
/// @param[in] I The (atomic) instruction to apply the mask to
/// @param[in] mask The mask to apply to the masked atomic
/// @param[out] toDelete mapping of deleted unmasked operations
/// @return true if it is valid to mask this atomic, false otherwise
bool applyMaskToAtomicRMW(AtomicRMWInst &atomicI, Value *mask,
DeletionMap &toDelete);
bool applyMaskToAtomic(Instruction &I, Value *mask, DeletionMap &toDelete);

/// @brief Linearize a CFG.
/// @return true if no problem occurred, false otherwise.
Expand Down Expand Up @@ -1138,9 +1137,7 @@ Error ControlFlowConversionState::Impl::applyMask(BasicBlock &BB, Value *mask) {
}
} else if (I.isAtomic() && !isa<FenceInst>(&I)) {
// Turn atomics into calls to masked builtins if possible.
// FIXME: We don't yet support masked cmpxchg instructions.
if (auto *atomicI = dyn_cast<AtomicRMWInst>(&I);
!atomicI || !applyMaskToAtomicRMW(*atomicI, mask, toDelete)) {
if (!applyMaskToAtomic(I, mask, toDelete)) {
return makeStringError("Could not apply mask to atomic instruction", I);
}
} else if (auto *branch = dyn_cast<BranchInst>(&I)) {
Expand Down Expand Up @@ -1372,41 +1369,66 @@ bool ControlFlowConversionState::Impl::applyMaskToCall(CallInst *CI,
return true;
}

bool ControlFlowConversionState::Impl::applyMaskToAtomicRMW(
AtomicRMWInst &atomicI, Value *mask, DeletionMap &toDelete) {
LLVM_DEBUG(dbgs() << "vecz-cf: Now at AtomicRMWInst " << atomicI << "\n");
bool ControlFlowConversionState::Impl::applyMaskToAtomic(
Instruction &I, Value *mask, DeletionMap &toDelete) {
LLVM_DEBUG(dbgs() << "vecz-cf: Now at atomic inst " << I << "\n");

VectorizationContext::MaskedAtomicRMW MA;
MA.Align = atomicI.getAlign();
MA.BinOp = atomicI.getOperation();
MA.IsVectorPredicated = VU.choices().vectorPredication();
MA.IsVolatile = atomicI.isVolatile();
MA.Ordering = atomicI.getOrdering();
MA.SyncScope = atomicI.getSyncScopeID();
SmallVector<Value *, 8> maskedFnArgs;
VectorizationContext::MaskedAtomic MA;
MA.VF = ElementCount::getFixed(1);
MA.ValTy = atomicI.getType();
MA.PointerTy = atomicI.getPointerOperand()->getType();
MA.IsVectorPredicated = VU.choices().vectorPredication();

if (auto *atomicI = dyn_cast<AtomicRMWInst>(&I)) {
MA.Align = atomicI->getAlign();
MA.BinOp = atomicI->getOperation();
MA.IsVolatile = atomicI->isVolatile();
MA.Ordering = atomicI->getOrdering();
MA.SyncScope = atomicI->getSyncScopeID();
MA.ValTy = atomicI->getType();
MA.PointerTy = atomicI->getPointerOperand()->getType();

// Set up the arguments to this function
maskedFnArgs = {atomicI->getPointerOperand(), atomicI->getValOperand(),
mask};

} else if (auto *cmpxchgI = dyn_cast<AtomicCmpXchgInst>(&I)) {
MA.Align = cmpxchgI->getAlign();
MA.BinOp = AtomicRMWInst::BAD_BINOP;
MA.IsWeak = cmpxchgI->isWeak();
MA.IsVolatile = cmpxchgI->isVolatile();
MA.Ordering = cmpxchgI->getSuccessOrdering();
MA.CmpXchgFailureOrdering = cmpxchgI->getFailureOrdering();
MA.SyncScope = cmpxchgI->getSyncScopeID();
MA.ValTy = cmpxchgI->getCompareOperand()->getType();
MA.PointerTy = cmpxchgI->getPointerOperand()->getType();

// Set up the arguments to this function
maskedFnArgs = {cmpxchgI->getPointerOperand(),
cmpxchgI->getCompareOperand(), cmpxchgI->getNewValOperand(),
mask};
} else {
return false;
}

// Create the new function and replace the old one with it
// Get the masked function
Function *newFunction = Ctx.getOrCreateMaskedAtomicRMWFunction(
Function *maskedAtomicFn = Ctx.getOrCreateMaskedAtomicFunction(
MA, VU.choices(), ElementCount::getFixed(1));
VECZ_FAIL_IF(!newFunction);
SmallVector<Value *, 8> fnArgs = {atomicI.getPointerOperand(),
atomicI.getValOperand(), mask};
VECZ_FAIL_IF(!maskedAtomicFn);
// We don't have a vector length just yet - pass in one as a dummy.
if (MA.IsVectorPredicated) {
fnArgs.push_back(
ConstantInt::get(IntegerType::getInt32Ty(atomicI.getContext()), 1));
maskedFnArgs.push_back(
ConstantInt::get(IntegerType::getInt32Ty(I.getContext()), 1));
}

CallInst *newCI = CallInst::Create(newFunction, fnArgs, "", &atomicI);
VECZ_FAIL_IF(!newCI);
CallInst *maskedCI = CallInst::Create(maskedAtomicFn, maskedFnArgs, "", &I);
VECZ_FAIL_IF(!maskedCI);

atomicI.replaceAllUsesWith(newCI);
toDelete.emplace_back(&atomicI, newCI);
I.replaceAllUsesWith(maskedCI);
toDelete.emplace_back(&I, maskedCI);

LLVM_DEBUG(dbgs() << "vecz-cf: Replaced " << atomicI << "\n");
LLVM_DEBUG(dbgs() << " with " << *newCI << "\n");
LLVM_DEBUG(dbgs() << "vecz-cf: Replaced " << I << "\n");
LLVM_DEBUG(dbgs() << " with " << *maskedCI << "\n");

return true;
}
Expand Down
17 changes: 16 additions & 1 deletion modules/compiler/vecz/source/transform/packetization_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <compiler/utils/group_collective_helpers.h>
#include <llvm/ADT/Twine.h>
#include <llvm/Analysis/VectorUtils.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
Expand All @@ -46,6 +47,18 @@ using namespace vecz;
namespace {
inline Type *getWideType(Type *ty, ElementCount factor) {
if (!ty->isVectorTy()) {
// The wide type of a struct literal is the wide type of each of its
// elements.
if (auto *structTy = dyn_cast<StructType>(ty);
structTy && structTy->isLiteral()) {
SmallVector<Type *, 4> wideElts(structTy->elements());
for (unsigned i = 0, e = wideElts.size(); i != e; i++) {
wideElts[i] = getWideType(wideElts[i], factor);
}
return StructType::get(ty->getContext(), wideElts);
} else if (structTy) {
VECZ_ERROR("Can't create wide type for structure type");
}
return VectorType::get(ty, factor);
}
bool const isScalable = isa<ScalableVectorType>(ty);
Expand Down Expand Up @@ -694,7 +707,9 @@ const Packetizer::Result &Packetizer::Result::broadcast(unsigned width) const {
auto &F = packetizer.F;
Value *result = nullptr;
const auto &TI = packetizer.context().targetInfo();
if (isa<UndefValue>(scalar)) {
if (isa<PoisonValue>(scalar)) {
result = PoisonValue::get(getWideType(ty, factor));
} else if (isa<UndefValue>(scalar)) {
result = UndefValue::get(getWideType(ty, factor));
} else if (ty->isVectorTy() && factor.isScalable()) {
IRBuilder<> B(buildAfter(scalar, F));
Expand Down
Loading

0 comments on commit d77f61e

Please sign in to comment.