Skip to content

Commit d8b84be

Browse files
authored
[MLIR][Transform][SMT] Introduce transform.smt.constrain_params (llvm#159450)
Introduces a Transform-dialect SMT-extension so that we can have an op to express constrains on Transform-dialect params, in particular when these params are knobs -- see transform.tune.knob -- and can hence be seen as symbolic variables. This op allows expressing joint constraints over multiple params/knobs together. While the op's semantics are clearly defined, per SMTLIB, the interpreted semantics -- i.e. the `apply()` method -- for now just defaults to failure. In the future we should support attaching an implementation so that users can Bring Your Own Solver and thereby control performance of interpreting the op. For now the main usage is to walk schedule IR and collect these constraints so that knobs can be rewritten to constants that satisfy the constraints.
1 parent c12f08f commit d8b84be

File tree

18 files changed

+461
-10
lines changed

18 files changed

+461
-10
lines changed

mlir/include/mlir/Dialect/Transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ add_subdirectory(IR)
44
add_subdirectory(IRDLExtension)
55
add_subdirectory(LoopExtension)
66
add_subdirectory(PDLExtension)
7+
add_subdirectory(SMTExtension)
78
add_subdirectory(Transforms)
89
add_subdirectory(TuneExtension)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS SMTExtensionOps.td)
2+
mlir_tablegen(SMTExtensionOps.h.inc -gen-op-decls)
3+
mlir_tablegen(SMTExtensionOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRTransformDialectSMTExtensionOpsIncGen)
5+
6+
add_mlir_doc(SMTExtensionOps SMTExtensionOps Dialects/ -gen-op-doc)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- SMTExtension.h - SMT extension for Transform dialect -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
10+
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
11+
12+
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
14+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15+
#include "mlir/IR/OpDefinition.h"
16+
#include "mlir/IR/OpImplementation.h"
17+
18+
namespace mlir {
19+
class DialectRegistry;
20+
21+
namespace transform {
22+
/// Registers the SMT extension of the Transform dialect in the given registry.
23+
void registerSMTExtension(DialectRegistry &dialectRegistry);
24+
} // namespace transform
25+
} // namespace mlir
26+
27+
#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- SMTExtensionOps.h - SMT extension for Transform dialect --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
10+
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
11+
12+
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
14+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15+
#include "mlir/IR/OpDefinition.h"
16+
#include "mlir/IR/OpImplementation.h"
17+
18+
#define GET_OP_CLASSES
19+
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h.inc"
20+
21+
#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//===- SMTExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
10+
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
15+
16+
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
17+
DeclareOpInterfaceMethods<TransformOpInterface>,
18+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
19+
NoTerminator
20+
]> {
21+
let cppNamespace = [{ mlir::transform::smt }];
22+
23+
let summary = "Express contraints on params interpreted as symbolic values";
24+
let description = [{
25+
Allows expressing constraints on params using the SMT dialect.
26+
27+
Each Transform dialect param provided as an operand has a corresponding
28+
argument of SMT-type in the region. The SMT-Dialect ops in the region use
29+
these arguments as operands.
30+
31+
The semantics of this op is that all the ops in the region together express
32+
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
33+
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
34+
op succeeds.
35+
36+
---
37+
38+
TODO: currently the operational semantics per the Transform interpreter is
39+
to always fail. The intention is build out support for hooking in your own
40+
operational semantics so you can invoke your favourite solver to determine
41+
satisfiability of the corresponding constraint problem.
42+
}];
43+
44+
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
45+
let regions = (region SizedRegion<1>:$body);
46+
let assemblyFormat =
47+
"`(` $params `)` attr-dict `:` type(operands) $body";
48+
49+
let hasVerifier = 1;
50+
}
51+
52+
#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS

mlir/lib/Bindings/Python/DialectSMT.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,26 @@ using namespace mlir::python::nanobind_adaptors;
2626

2727
static void populateDialectSMTSubmodule(nanobind::module_ &m) {
2828

29-
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
30-
.def_classmethod(
31-
"get",
32-
[](const nb::object &, MlirContext context) {
33-
return mlirSMTTypeGetBool(context);
34-
},
35-
"cls"_a, "context"_a = nb::none());
29+
auto smtBoolType =
30+
mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
31+
.def_staticmethod(
32+
"get",
33+
[](MlirContext context) { return mlirSMTTypeGetBool(context); },
34+
"context"_a = nb::none());
3635
auto smtBitVectorType =
3736
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
38-
.def_classmethod(
37+
.def_staticmethod(
3938
"get",
40-
[](const nb::object &, int32_t width, MlirContext context) {
39+
[](int32_t width, MlirContext context) {
4140
return mlirSMTTypeGetBitVector(context, width);
4241
},
43-
"cls"_a, "width"_a, "context"_a = nb::none());
42+
"width"_a, "context"_a = nb::none());
43+
auto smtIntType =
44+
mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
45+
.def_staticmethod(
46+
"get",
47+
[](MlirContext context) { return mlirSMTTypeGetInt(context); },
48+
"context"_a = nb::none());
4449

4550
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
4651
bool indentLetBody) {

mlir/lib/Dialect/Transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_subdirectory(IR)
44
add_subdirectory(IRDLExtension)
55
add_subdirectory(LoopExtension)
66
add_subdirectory(PDLExtension)
7+
add_subdirectory(SMTExtension)
78
add_subdirectory(Transforms)
89
add_subdirectory(TuneExtension)
910
add_subdirectory(Utils)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
add_mlir_dialect_library(MLIRTransformSMTExtension
2+
SMTExtension.cpp
3+
SMTExtensionOps.cpp
4+
5+
DEPENDS
6+
MLIRTransformDialectSMTExtensionOpsIncGen
7+
8+
LINK_LIBS PUBLIC
9+
MLIRIR
10+
MLIRTransformDialect
11+
MLIRSMT
12+
)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===- SMTExtension.cpp - SMT extension for the Transform dialect ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
10+
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
11+
#include "mlir/IR/DialectRegistry.h"
12+
13+
using namespace mlir;
14+
15+
//===----------------------------------------------------------------------===//
16+
// Transform op registration
17+
//===----------------------------------------------------------------------===//
18+
19+
namespace {
20+
class SMTExtension : public transform::TransformDialectExtension<SMTExtension> {
21+
public:
22+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SMTExtension)
23+
24+
SMTExtension() {
25+
registerTransformOps<
26+
#define GET_OP_LIST
27+
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
28+
>();
29+
}
30+
};
31+
} // namespace
32+
33+
void mlir::transform::registerSMTExtension(DialectRegistry &dialectRegistry) {
34+
dialectRegistry.addExtensions<SMTExtension>();
35+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===- SMTExtensionOps.cpp - SMT extension for the Transform dialect ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
10+
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
11+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
12+
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
13+
14+
using namespace mlir;
15+
16+
#define GET_OP_CLASSES
17+
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
18+
19+
//===----------------------------------------------------------------------===//
20+
// ConstrainParamsOp
21+
//===----------------------------------------------------------------------===//
22+
23+
void transform::smt::ConstrainParamsOp::getEffects(
24+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
25+
onlyReadsHandle(getParamsMutable(), effects);
26+
}
27+
28+
DiagnosedSilenceableFailure
29+
transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
30+
transform::TransformResults &results,
31+
transform::TransformState &state) {
32+
// TODO: Proper operational semantics are to check the SMT problem in the body
33+
// with a SMT solver with the arguments of the body constrained to the
34+
// values passed into the op. Success or failure is then determined by
35+
// the solver's result.
36+
// One way to support this is to just promise the TransformOpInterface
37+
// and allow for users to attach their own implementation, which would,
38+
// e.g., translate the ops to SMTLIB and hand that over to the user's
39+
// favourite solver. This requires changes to the dialect's verifier.
40+
return emitDefiniteFailure() << "op does not have interpreted semantics yet";
41+
}
42+
43+
LogicalResult transform::smt::ConstrainParamsOp::verify() {
44+
if (getOperands().size() != getBody().getNumArguments())
45+
return emitOpError(
46+
"must have the same number of block arguments as operands");
47+
48+
for (auto &op : getBody().getOps()) {
49+
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
50+
return emitOpError(
51+
"ops contained in region should belong to SMT-dialect");
52+
}
53+
54+
return success();
55+
}

0 commit comments

Comments
 (0)