-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][ptr] Add gather
, masked_load
, masked_store
, and scatter
ops
#156368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/fabianmcg/ptr-translation
Are you sure you want to change the base?
[mlir][ptr] Add gather
, masked_load
, masked_store
, and scatter
ops
#156368
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Fabian Mora (fabianmcg) ChangesThis patch adds the
Example: llvm.func @<!-- -->mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
%mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
%0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
ptr.scatter %value, %ptrs, %mask alignment = 8 : vector<4xf64>, vector<4x!ptr.ptr<#llvm.address_space<3>>>
%1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 8 : !ptr.ptr<#llvm.address_space<3>> -> vector<4xf64>
ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
llvm.return
} Translates to: define void @<!-- -->mixed_masked_ops_address_spaces(ptr addrspace(3) %0, <4 x ptr addrspace(3)> %1, <4 x i1> %2, <4 x double> %3, <4 x double> %4) {
%6 = call <4 x double> @<!-- -->llvm.masked.gather.v4f64.v4p3(<4 x ptr addrspace(3)> %1, i32 8, <4 x i1> %2, <4 x double> %4)
call void @<!-- -->llvm.masked.scatter.v4f64.v4p3(<4 x double> %3, <4 x ptr addrspace(3)> %1, i32 8, <4 x i1> %2)
%7 = call <4 x double> @<!-- -->llvm.masked.load.v4f64.p3(ptr addrspace(3) %0, i32 8, <4 x i1> %2, <4 x double> %4)
call void @<!-- -->llvm.masked.store.v4f64.p3(<4 x double> %3, ptr addrspace(3) %0, i32 8, <4 x i1> %2)
ret void
} Patch is 35.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156368.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 1c88efced950e..170513d57c7be 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,46 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"
+//===----------------------------------------------------------------------===//
+// Common props
+//===----------------------------------------------------------------------===//
+
+def AlignmentProp : OptionalProp<I64Prop>;
+
+//===----------------------------------------------------------------------===//
+// Common types
+//===----------------------------------------------------------------------===//
+
+// A shaped value type with value semantics and rank.
+class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
+ ShapedContainerType<allowedTypes,
+ /*containerPred=*/And<[HasValueSemanticsPred] # preds>,
+ /*descr=*/[{A shaped type with value semantics and rank.}],
+ /*cppType=*/"::mlir::ShapedType">;
+
+// A shaped pointer type with value semantics and rank.
+class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+
+// A shaped value type of rank 1 of any element type.
+def Ptr_Any1DType :
+ Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `i1` element type.
+def Ptr_Mask1DType :
+ Ptr_ShapedValueType<[I1], [HasAnyRankOfPred<[1]>]>;
+
+// A shaped value type of rank 1 of `i1` element type.
+def Ptr_Ptr1DType :
+ Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
+
+// Gets the type ID of a type.
+class TypeIDType<string name> :
+ StrFunc<"$" # name # ".getType().getTypeID()">;
+
+// Checks that all type IDs match.
+class AllTypeIDsMatch<list<string> names> :
+ AllMatchSameOperatorTrait<names, TypeIDType<"_self">.result, "type IDs">;
+
//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//
@@ -56,6 +96,58 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GatherOp : Pointer_Op<"gather", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>,
+ AllTypesMatch<["result", "passthrough"]>,
+ // Check the shapes are compatible and both use the same shaped container
+ // type.
+ AllShapesMatch<["result", "ptrs"]>, AllTypeIDsMatch<["result", "ptrs"]>
+ ]> {
+ let summary = "Gather operation";
+ let description = [{
+ The `gather` operation performs conditional loads from multiple memory
+ locations specified by `ptrs` based on a mask `mask`. Elements of the
+ result corresponding to masked-off lanes are taken from the passthrough
+ operand.
+
+ The mask operand is a shaped type of `i1` elements that must have the same
+ shape as the result type.
+
+ Examples:
+ ```mlir
+ // Gather values from multiple memory locations
+ %result = ptr.gather %ptrs, %mask, %passthrough :
+ vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+
+ // Gather with alignment
+ %result = ptr.gather %ptrs, %mask, %passthrough alignment = 8 :
+ vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
+ ```
+ }];
+ let arguments = (ins Ptr_Ptr1DType:$ptrs,
+ Ptr_Mask1DType:$mask,
+ Ptr_Any1DType:$passthrough,
+ AlignmentProp:$alignment);
+ let results = (outs Ptr_Any1DType:$result);
+ let assemblyFormat = [{
+ $ptrs `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+ attr-dict `:` qualified(type($ptrs)) `->` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
+ "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
+ ];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// GetMetadataOp
//===----------------------------------------------------------------------===//
@@ -122,8 +214,6 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
// LoadOp
//===----------------------------------------------------------------------===//
-def AlignmentProp : OptionalProp<I64Prop>;
-
def Ptr_LoadOp : Pointer_Op<"load", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
@@ -184,6 +274,150 @@ def Ptr_LoadOp : Pointer_Op<"load", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_MaskedLoadOp : Pointer_Op<"masked_load", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>,
+ AllTypesMatch<["result", "passthrough"]>
+ ]> {
+ let summary = "Masked load operation";
+ let description = [{
+ The `masked_load` operation performs a conditional load from memory based
+ on a mask. Elements of the result corresponding to masked-off lanes are
+ taken from the passthrough operand.
+
+ The mask operand is a shaped type of `i1` elements that must have the same
+ shape as the result type.
+
+ Examples:
+ ```mlir
+ // Masked load with passthrough on vectors
+ %result = ptr.masked_load %ptr, %mask, %passthrough :
+ !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
+
+ // Masked load with passthrough on tensors
+ %result = ptr.masked_load %ptr, %mask, %passthrough :
+ !ptr.ptr<#ptr.generic_space> -> tensor<4xf32>
+ ```
+ }];
+ let arguments = (ins Ptr_PtrType:$ptr,
+ Ptr_Mask1DType:$mask,
+ Ptr_Any1DType:$passthrough,
+ AlignmentProp:$alignment);
+ let results = (outs Ptr_Any1DType:$result);
+ let assemblyFormat = [{
+ $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+ attr-dict `:` qualified(type($ptr)) `->` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
+ "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
+ ];
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>
+ ]> {
+ let summary = "Masked store operation";
+ let description = [{
+ The `masked_store` operation performs a conditional store to memory based
+ on a mask. Only elements corresponding to set bits in the mask are written
+ to memory.
+
+ The mask operand is a shaped type of `i1` elements that must have the same
+ shape as the value being stored.
+
+ Examples:
+ ```mlir
+ // Masked store
+ ptr.masked_store %value, %ptr, %mask :
+ vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+
+ // Masked store with alignment
+ ptr.masked_store %value, %ptr, %mask alignment = 8 :
+ vector<4xf32>, !ptr.ptr<#ptr.generic_space>
+ ```
+ }];
+
+ let arguments = (ins Ptr_Any1DType:$value,
+ Ptr_PtrType:$ptr,
+ Ptr_Mask1DType:$mask,
+ AlignmentProp:$alignment);
+ let assemblyFormat = [{
+ $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)? attr-dict `:`
+ type($value) `,` qualified(type($ptr))
+ }];
+ let builders = [
+ OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+ CArg<"unsigned", "0">:$alignment)>
+ ];
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ScatterOp : Pointer_Op<"scatter", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
+ ::llvm::cast<ShapedType>($_self).clone(
+ IntegerType::get($_self.getContext(), 1))
+ }]>,
+ // Check the shapes are compatible and both use the same shaped container
+ // type.
+ AllShapesMatch<["value", "ptrs"]>, AllTypeIDsMatch<["value", "ptrs"]>
+ ]> {
+ let summary = "Scatter operation";
+ let description = [{
+ The `scatter` operation performs a conditional store of a value `value` to
+ multiple memory locations specified by `ptrs` based on a mask `mask`.
+
+ Only elements corresponding to set bits in the mask are written to memory.
+ The mask operand is a shaped type of `i1` elements that must have the same
+ shape as the value being stored.
+
+ Examples:
+ ```mlir
+ // Scatter values to multiple memory locations
+ ptr.scatter %value, %ptrs, %mask :
+ vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+
+ // Scatter with alignment
+ ptr.scatter %value, %ptrs, %mask alignment = 8 :
+ vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
+ ```
+ }];
+ let arguments = (ins Ptr_Any1DType:$value,
+ Ptr_Ptr1DType:$ptrs,
+ Ptr_Mask1DType:$mask,
+ AlignmentProp:$alignment);
+ let assemblyFormat = [{
+ $value `,` $ptrs `,` $mask (`alignment` `=` $alignment^)?
+ attr-dict `:` type($value) `,` qualified(type($ptrs))
+ }];
+ let builders = [
+ OpBuilder<(ins "Value":$value, "Value":$ptrs, "Value":$mask,
+ CArg<"unsigned", "0">:$alignment)>
+ ];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 6987926db7e5c..81ae4efd8ec87 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -39,6 +39,23 @@ void PtrDialect::initialize() {
>();
}
+//===----------------------------------------------------------------------===//
+// Common helper functions.
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the alignment attribute is a power of 2 if present.
+static LogicalResult
+verifyAlignment(std::optional<int64_t> alignment,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (!alignment)
+ return success();
+ if (alignment.value() <= 0)
+ return emitError() << "alignment must be positive";
+ if (!llvm::isPowerOf2_64(alignment.value()))
+ return emitError() << "alignment must be a power of 2";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//
@@ -84,6 +101,39 @@ LogicalResult FromPtrOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+void GatherOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // Gather performs reads from multiple memory locations specified by ptrs
+ effects.emplace_back(MemoryEffects::Read::get(), &getPtrsMutable());
+}
+
+LogicalResult GatherOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+ // Verify that the pointer type's memory space allows loads.
+ MemorySpaceAttrInterface ms =
+ cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), dataLayout, emitDiag))
+ return failure();
+
+ // Verify the alignment.
+ return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void GatherOp::build(OpBuilder &builder, OperationState &state, Type resultType,
+ Value ptrs, Value mask, Value passthrough,
+ unsigned alignment) {
+ build(builder, state, resultType, ptrs, mask, passthrough,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
@@ -107,19 +157,6 @@ verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
return success();
}
-/// Verifies that the alignment attribute is a power of 2 if present.
-static LogicalResult
-verifyAlignment(std::optional<int64_t> alignment,
- function_ref<InFlightDiagnostic()> emitError) {
- if (!alignment)
- return success();
- if (alignment.value() <= 0)
- return emitError() << "alignment must be positive";
- if (!llvm::isPowerOf2_64(alignment.value()))
- return emitError() << "alignment must be a power of 2";
- return success();
-}
-
void LoadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -158,6 +195,99 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
}
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+void MaskedLoadOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // MaskedLoad performs reads from the memory location specified by ptr.
+ effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+}
+
+LogicalResult MaskedLoadOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ // Verify that the pointer type's memory space allows loads.
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), dataLayout, emitDiag))
+ return failure();
+
+ // Verify the alignment.
+ return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void MaskedLoadOp::build(OpBuilder &builder, OperationState &state,
+ Type resultType, Value ptr, Value mask,
+ Value passthrough, unsigned alignment) {
+ build(builder, state, resultType, ptr, mask, passthrough,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+void MaskedStoreOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // MaskedStore performs writes to the memory location specified by ptr
+ effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+}
+
+LogicalResult MaskedStoreOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ // Verify that the pointer type's memory space allows stores.
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), dataLayout, emitDiag))
+ return failure();
+
+ // Verify the alignment.
+ return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value ptr, Value mask,
+ unsigned alignment) {
+ build(builder, state, value, ptr, mask,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+void ScatterOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ // Scatter performs writes to multiple memory locations specified by ptrs
+ effects.emplace_back(MemoryEffects::Write::get(), &getPtrsMutable());
+}
+
+LogicalResult ScatterOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+ // Verify that the pointer type's memory space allows stores.
+ MemorySpaceAttrInterface ms =
+ cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+ getAlignment(), dataLayout, emitDiag))
+ return failure();
+
+ // Verify the alignment.
+ return verifyAlignment(getAlignment(), emitDiag);
+}
+
+void ScatterOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value ptrs, Value mask, unsigned alignment) {
+ build(builder, state, value, ptrs, mask,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt);
+}
//===----------------------------------------------------------------------===//
// StoreOp
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 906e19901617b..ede3d0de90996 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -207,6 +207,112 @@ convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
return success();
}
+/// Convert ptr.gather operation
+static LogicalResult
+convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
+ llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
+ llvm::Value *passthrough =
+ moduleTranslation.lookupValue(gatherOp.getPassthrough());
+
+ if (!ptrs || !mask || !passthrough)
+ return gatherOp.emitError("Failed to lookup operands");
+
+ // Convert result type to LLVM type.
+ llvm::Type *resultType =
+ moduleTranslation.convertType(gatherOp.getResult().getType());
+ if (!resultType)
+ return gatherOp.emitError("Failed to convert result type");
+
+ // Get the alignment.
+ llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
+
+ // Create the masked gather intrinsic call.
+ llvm::Value *result = builder.CreateMaskedGather(
+ resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
+
+ moduleTranslation.mapValue(gatherOp.getResult(), result);
+ return success();
+}
+
+/// Convert ptr.masked_load operation
+static LogicalResult
+convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
+ llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
+ llvm::Value *passthrough =
+ moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
+
+ if (!ptr || !mask || !passthrough)
+ return maskedLoadOp.emitError("Failed to lookup operands");
+
+ // Conver...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds four new operations to the ptr
dialect for masked memory operations: gather
, masked_load
, masked_store
, and scatter
. These operations enable conditional memory access patterns commonly used in SIMD/vectorized code, with corresponding translations to LLVM masked memory intrinsics.
- Addition of four new memory operations with mask-based conditional access
- Implementation of LLVM IR translation for these operations to corresponding LLVM intrinsics
- Comprehensive test coverage for various data types, alignments, and address spaces
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
Show a summary per file
File | Description |
---|---|
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | Defines the four new operations with their syntax, constraints, and documentation |
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | Implements verification logic, side effects, and builder methods for the new operations |
mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp | Adds translation functions to convert the new operations to LLVM masked memory intrinsics |
mlir/test/Dialect/Ptr/ops.mlir | Tests operation parsing and verification with various type combinations |
mlir/test/Target/LLVMIR/ptr.mlir | Tests LLVM IR translation output for the new operations |
|
StrFunc<"$" # name # ".getType().getTypeID()">; | ||
|
||
// Checks that all type IDs match. | ||
class AllTypeIDsMatch<list<string> names> : |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: does not seem like a ptr-specific helper?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't add it to OpBase.td
because this is the only consumer. I'll move it to OpBase.td
.
let results = (outs Ptr_Any1DType:$result); | ||
let assemblyFormat = [{ | ||
$ptrs `,` $mask `,` $passthrough (`alignment` `=` $alignment^)? | ||
attr-dict `:` qualified(type($ptrs)) `->` type($result) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the qualified
trying to address here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I generally prefer qualified type, but I'll remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was mostly trying to figure out in this case what is the impact? Because the examples show only vector
which isn't impacted by this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, I'll remove.
17a5113
to
287c3b2
Compare
This patch adds the
gather
,masked_load
,masked_store
, andscatter
operations to theptr
dialect. It also implements translation from these operations to LLVM intrinsics:Example:
Translates to: