Skip to content

Commit

Permalink
[mlir][Transforms] Merge 1:1 and 1:N type converters (#113032)
Browse files Browse the repository at this point in the history
The 1:N type converter derived from the 1:1 type converter and extends
it with 1:N target materializations. This commit merges the two type
converters and stores 1:N target materializations in the 1:1 type
converter. This is in preparation of merging the 1:1 and 1:N dialect
conversion infrastructures.

1:1 target materializations (producing a single `Value`) will remain
valid. An additional API is added to the type converter to register 1:N
target materializations (producing a `SmallVector<Value>`). Internally,
all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the `OneToNTypeConverter`,
simply switch all occurrences to `TypeConverter`.

---------

Co-authored-by: Markus Böck <[email protected]>
  • Loading branch information
matthias-springer and zero9178 authored Oct 25, 2024
1 parent 9648271 commit 8c4bc1e
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 98 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
//===----------------------------------------------------------------------===//

/// Type converter for iter_space and iterator.
struct SparseIterationTypeConverter : public OneToNTypeConverter {
struct SparseIterationTypeConverter : public TypeConverter {
SparseIterationTypeConverter();
};

Expand Down
62 changes: 48 additions & 14 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ class TypeConverter {
/// conversion has finished.
///
/// Note: Target materializations may optionally accept an additional Type
/// parameter, which is the original type of the SSA value.
/// parameter, which is the original type of the SSA value. Furthermore, `T`
/// can be a TypeRange; in that case, the function must return a
/// SmallVector<Value>.

/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
Expand Down Expand Up @@ -210,6 +212,9 @@ class TypeConverter {
/// will be invoked with: outputType = "t3", inputs = "v2",
// originalType = "t1". Note that the original type "t1" cannot be recovered
/// from just "t3" and "v2"; that's why the originalType parameter exists.
///
/// Note: During a 1:N conversion, the result types can be a TypeRange. In
/// that case the materialization produces a SmallVector<Value>.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
Expand Down Expand Up @@ -316,6 +321,11 @@ class TypeConverter {
Value materializeTargetConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs,
Type originalType = {}) const;
SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
Location loc,
TypeRange resultType,
ValueRange inputs,
Type originalType = {}) const;

/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
Expand All @@ -340,9 +350,9 @@ class TypeConverter {

/// The signature of the callback used to materialize a target conversion.
///
/// Arguments: builder, result type, inputs, location, original type
using TargetMaterializationCallbackFn =
std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
/// Arguments: builder, result types, inputs, location, original type
using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
OpBuilder &, TypeRange, ValueRange, Location, Type)>;

/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
Expand Down Expand Up @@ -409,32 +419,56 @@ class TypeConverter {
/// callback.
///
/// With callback of form:
/// `Value(OpBuilder &, T, ValueRange, Location, Type)`
/// - Value(OpBuilder &, T, ValueRange, Location, Type)
/// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc, Type originalType) -> Value {
if (T derivedType = dyn_cast<T>(resultType))
return callback(builder, derivedType, inputs, loc, originalType);
return Value();
OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc, Type originalType) -> SmallVector<Value> {
SmallVector<Value> result;
if constexpr (std::is_same<T, TypeRange>::value) {
// This is a 1:N target materialization. Return the produces values
// directly.
result = callback(builder, resultTypes, inputs, loc, originalType);
} else if constexpr (std::is_assignable<Type, T>::value) {
// This is a 1:1 target materialization. Invoke the callback only if a
// single SSA value is requested.
if (resultTypes.size() == 1) {
// Invoke the callback only if the type class of the callback matches
// the requested result type.
if (T derivedType = dyn_cast<T>(resultTypes.front())) {
// 1:1 materializations produce single values, but we store 1:N
// target materialization functions in the type converter. Wrap the
// result value in a SmallVector<Value>.
Value val =
callback(builder, derivedType, inputs, loc, originalType);
if (val)
result.push_back(val);
}
}
} else {
static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
}
return result;
};
}
/// With callback of form:
/// `Value(OpBuilder &, T, ValueRange, Location)`
/// - Value(OpBuilder &, T, ValueRange, Location)
/// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return wrapTargetMaterialization<T>(
[callback = std::forward<FnT>(callback)](
OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
Type originalType) -> Value {
return callback(builder, resultType, inputs, loc);
OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
Type originalType) {
return callback(builder, resultTypes, inputs, loc);
});
}

Expand Down
45 changes: 1 addition & 44 deletions mlir/include/mlir/Transforms/OneToNTypeConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,49 +33,6 @@

namespace mlir {

/// Extends `TypeConverter` with 1:N target materializations. Such
/// materializations have to provide the "reverse" of 1:N type conversions,
/// i.e., they need to materialize N values with target types into one value
/// with a source type (which isn't possible in the base class currently).
class OneToNTypeConverter : public TypeConverter {
public:
/// Callback that expresses user-provided materialization logic from the given
/// value to N values of the given types. This is useful for expressing target
/// materializations for 1:N type conversions, which materialize one value in
/// a source type as N values in target types.
using OneToNMaterializationCallbackFn =
std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
Value, Location)>;

/// Creates the mapping of the given range of original types to target types
/// of the conversion and stores that mapping in the given (signature)
/// conversion. This function simply calls
/// `TypeConverter::convertSignatureArgs` and exists here with a different
/// name to reflect the broader semantic.
LogicalResult computeTypeMapping(TypeRange types,
SignatureConversion &result) const {
return convertSignatureArgs(types, result);
}

/// Applies one of the user-provided 1:N target materializations. If several
/// exists, they are tried out in the reverse order in which they have been
/// added until the first one succeeds. If none succeeds, the functions
/// returns `std::nullopt`.
std::optional<SmallVector<Value>>
materializeTargetConversion(OpBuilder &builder, Location loc,
TypeRange resultTypes, Value input) const;

/// Adds a 1:N target materialization to the converter. Such materializations
/// build IR that converts N values with target types into 1 value of the
/// source type.
void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
oneToNTargetMaterializations.emplace_back(std::move(callback));
}

private:
SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
};

/// Stores a 1:N mapping of types and provides several useful accessors. This
/// class extends `SignatureConversion`, which already supports 1:N type
/// mappings but lacks some accessors into the mapping as well as access to the
Expand Down Expand Up @@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
/// not fail if some ops or types remain unconverted (i.e., the conversion is
/// only "partial").
LogicalResult
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns);

/// Add a pattern to the given pattern list to convert the signature of a
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
RewritePatternSet patterns(context);
converter.addConversion([](Type type) { return type; });
converter.addConversion(
Expand Down
26 changes: 22 additions & 4 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2831,11 +2831,29 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs,
Type originalType) const {
SmallVector<Value> result = materializeTargetConversion(
builder, loc, TypeRange(resultType), inputs, originalType);
if (result.empty())
return nullptr;
assert(result.size() == 1 && "expected single result");
return result.front();
}

SmallVector<Value> TypeConverter::materializeTargetConversion(
OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
Type originalType) const {
for (const TargetMaterializationCallbackFn &fn :
llvm::reverse(targetMaterializations))
if (Value result = fn(builder, resultType, inputs, loc, originalType))
return result;
return nullptr;
llvm::reverse(targetMaterializations)) {
SmallVector<Value> result =
fn(builder, resultTypes, inputs, loc, originalType);
if (result.empty())
continue;
assert(TypeRange(result) == resultTypes &&
"callback produced incorrect number of values or values with "
"incorrect types");
return result;
}
return {};
}

std::optional<TypeConverter::SignatureConversion>
Expand Down
44 changes: 14 additions & 30 deletions mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,6 @@
using namespace llvm;
using namespace mlir;

std::optional<SmallVector<Value>>
OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc,
TypeRange resultTypes,
Value input) const {
for (const OneToNMaterializationCallbackFn &fn :
llvm::reverse(oneToNTargetMaterializations)) {
if (std::optional<SmallVector<Value>> result =
fn(builder, resultTypes, input, loc))
return *result;
}
return std::nullopt;
}

TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
TypeRange convertedTypes = getConvertedTypes();
if (auto mapping = getInputMapping(originalTypeNo))
Expand Down Expand Up @@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
LogicalResult
OneToNConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
auto *typeConverter = getTypeConverter();

// Construct conversion mapping for results.
Operation::result_type_range originalResultTypes = op->getResultTypes();
OneToNTypeMapping resultMapping(originalResultTypes);
if (failed(typeConverter->computeTypeMapping(originalResultTypes,
resultMapping)))
if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
resultMapping)))
return failure();

// Construct conversion mapping for operands.
Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
OneToNTypeMapping operandMapping(originalOperandTypes);
if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
operandMapping)))
if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
operandMapping)))
return failure();

// Cast operands to target types.
Expand Down Expand Up @@ -318,7 +304,7 @@ namespace mlir {
// inserted by this pass are annotated with a string attribute that also
// documents which kind of the cast (source, argument, or target).
LogicalResult
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns) {
#ifndef NDEBUG
// Remember existing unrealized casts. This data structure is only used in
Expand Down Expand Up @@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
// Target materialization.
assert(!areOperandTypesLegal && areResultsTypesLegal &&
operands.size() == 1 && "found unexpected target cast");
std::optional<SmallVector<Value>> maybeResults =
typeConverter.materializeTargetConversion(
rewriter, castOp->getLoc(), resultTypes, operands.front());
if (!maybeResults) {
materializedResults = typeConverter.materializeTargetConversion(
rewriter, castOp->getLoc(), resultTypes, operands.front());
if (materializedResults.empty()) {
emitError(castOp->getLoc())
<< "failed to create target materialization";
return failure();
}
materializedResults = maybeResults.value();
} else {
// Source and argument materializations.
assert(areOperandTypesLegal && !areResultsTypesLegal &&
Expand Down Expand Up @@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
const OneToNTypeMapping &resultMapping,
ValueRange convertedOperands) const override {
auto funcOp = cast<FunctionOpInterface>(op);
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
auto *typeConverter = getTypeConverter();

// Construct mapping for function arguments.
OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
argumentMapping)))
if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
argumentMapping)))
return failure();

// Construct mapping for function results.
OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
funcResultMapping)))
if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
funcResultMapping)))
return failure();

// Nothing to do if the op doesn't have any non-identity conversions for its
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
///
/// This function has been copied (with small adaptions) from
/// TestDecomposeCallGraphTypes.cpp.
static std::optional<SmallVector<Value>>
buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
Location loc) {
static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
TypeRange resultTypes,
ValueRange inputs,
Location loc) {
if (inputs.size() != 1)
return {};
Value input = inputs.front();

TupleType inputType = dyn_cast<TupleType>(input.getType());
if (!inputType)
return {};
Expand Down Expand Up @@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
auto *context = &getContext();

// Assemble type converter.
OneToNTypeConverter typeConverter;
TypeConverter typeConverter;

typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
Expand All @@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildGetTupleElementOps);
// Test the other target materialization variant that takes the original type
// as additional argument. This materialization function always fails.
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc, Type originalType) -> SmallVector<Value> { return {}; });

// Assemble patterns.
RewritePatternSet patterns(context);
Expand Down

0 comments on commit 8c4bc1e

Please sign in to comment.