Skip to content

Commit 4572127

Browse files
committed
[NFC] Switch to new pass generation tablegen definitions.
This commit completes the migration from the deprecated GEN_PASS_CLASSES to the new GEN_PASS_DEF infrastructure across all torch-mlir passes. Changes include: 1. Remove PassDetail.h files (deprecated pattern) - Deleted lib/Conversion/PassDetail.h - Deleted lib/RefBackend/PassDetail.h - Deleted lib/Dialect/Torch/Transforms/PassDetail.h - Deleted lib/Dialect/TorchConversion/Transforms/PassDetail.h - Deleted lib/Dialect/TMTensor/Transforms/PassDetail.h 2. Migrate conversion passes to GEN_PASS_DEF - Updated all passes in lib/Conversion/ to use #define GEN_PASS_DEF_* - Removed GEN_PASS_DECL from .cpp files (move to headers where needed) - Fixed includes and namespace declarations 3. Migrate dialect transform passes - Updated Torch, TorchConversion, and TMTensor transform passes - Properly scoped GEN_PASS_DEF in namespace blocks 4. Handle passes with options (TorchToStablehlo, TorchToTosa) - Added GEN_PASS_DECL_* to headers - Implemented default and convenience create functions - Used generated constructors via `using BaseClass::BaseClass` 5. Handle passes without options (RefBackend) - Removed manual create function implementations - Let tablegen auto-generate create functions - Added using declarations for Base classes in impl namespace 6. Fix backend type conversion passes - Added missing create functions in BackendTypeConversionPasses.cpp - Fixed namespace scoping issues 7. Fix missing namespace closures - Added proper closing namespace comments in Verify*BackendContract.cpp The migration maintains full backward compatibility while adopting the recommended LLVM pass infrastructure patterns. All passes now use the generated base classes and follow consistent patterns based on whether they have options defined in tablegen. Signed-off-by: hanhanW <[email protected]>
1 parent 8d563af commit 4572127

File tree

57 files changed

+492
-440
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+492
-440
lines changed

include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h

Lines changed: 0 additions & 27 deletions
This file was deleted.

include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,18 @@
1616

1717
namespace mlir {
1818
namespace torch {
19+
20+
#define GEN_PASS_DECL_CONVERTTORCHTOSTABLEHLO
21+
#include "torch-mlir/Conversion/Passes.h.inc"
22+
1923
std::unique_ptr<OperationPass<func::FuncOp>>
2024
createConvertTorchToStablehloPass();
25+
26+
// Convenience wrapper for users who want to pass options as individual
27+
// parameters
2128
std::unique_ptr<OperationPass<func::FuncOp>>
2229
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
30+
2331
} // namespace torch
2432
} // namespace mlir
2533

include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
namespace mlir {
2020
namespace torch {
2121

22+
#define GEN_PASS_DECL_CONVERTTORCHTOTOSA
23+
#include "torch-mlir/Conversion/Passes.h.inc"
24+
2225
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
2326
/// dialect.
2427
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
@@ -30,8 +33,12 @@ populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
3033
RewritePatternSet &patterns);
3134

3235
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
36+
37+
// Convenience wrapper for users who want to pass options as individual
38+
// parameters
3339
std::unique_ptr<OperationPass<func::FuncOp>>
3440
createConvertTorchToTosaPass(bool requireFullTosaConversion);
41+
3542
} // namespace torch
3643
} // namespace mlir
3744

include/torch-mlir/RefBackend/Passes.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,12 @@
1515
#include "mlir/Pass/PassManager.h"
1616

1717
namespace mlir {
18-
class ModuleOp;
19-
2018
namespace torch {
2119
namespace RefBackend {
2220

2321
/// Registers all RefBackend passes.
2422
void registerRefBackendPasses();
2523

26-
std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
27-
28-
std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
29-
30-
std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
31-
32-
std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
33-
34-
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorConcatPass();
35-
36-
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
3724
} // namespace RefBackend
3825
} // namespace torch
3926
} // namespace mlir

include/torch-mlir/RefBackend/Passes.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,29 @@ include "mlir/Pass/PassBase.td"
1414

1515
def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleOp"> {
1616
let summary = "Munge calling conventions for calling via ExecutionEngine";
17-
let constructor = "mlir::torch::RefBackend::createMungeCallingConventionsPass();";
1817
let dependentDialects = ["memref::MemRefDialect"];
1918
}
2019

2120
def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> {
2221
let summary = "Bufferize the MLProgram dialect ops";
23-
let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();";
2422
let dependentDialects = ["memref::MemRefDialect"];
2523
}
2624

2725
def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> {
2826
let summary = "Expand ops into more primitive ops before LLVM lowering.";
29-
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
3027
}
3128

3229
def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
3330
let summary = "Munge memref.copy to linalg.copy";
34-
let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
3531
let dependentDialects = ["memref::MemRefDialect"];
3632
}
3733

3834
def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> {
3935
let summary = "Convert tensor.concat to other tensor ops";
40-
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()";
4136
}
4237

4338
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
4439
let summary = "Convert tensor.pad to linalg ops";
45-
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
4640
}
4741

4842
#endif // TORCHMLIR_REFBACKEND_PASSES

lib/Conversion/PassDetail.h

Lines changed: 0 additions & 26 deletions
This file was deleted.

lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Pass/Pass.h"
13+
#include "torch-mlir/Conversion/Passes.h"
1114

12-
#include "../PassDetail.h"
1315
#include "mlir/Dialect/Arith/IR/Arith.h"
1416
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
1517
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -20,6 +22,10 @@ using namespace mlir;
2022
using namespace mlir::torch;
2123
using namespace mlir::torch::Torch;
2224
using namespace mlir::torch::TorchConversion;
25+
namespace mlir::torch {
26+
27+
#define GEN_PASS_DEF_CONVERTTORCHCONVERSIONTOMLPROGRAM
28+
#include "torch-mlir/Conversion/Passes.h.inc"
2329

2430
static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
2531

@@ -102,7 +108,7 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
102108

103109
namespace {
104110
class ConvertTorchConversionToMLProgram
105-
: public ConvertTorchConversionToMLProgramBase<
111+
: public impl::ConvertTorchConversionToMLProgramBase<
106112
ConvertTorchConversionToMLProgram> {
107113
public:
108114
void getDependentDialects(DialectRegistry &registry) const override {
@@ -138,6 +144,8 @@ class ConvertTorchConversionToMLProgram
138144
} // namespace
139145

140146
std::unique_ptr<OperationPass<ModuleOp>>
141-
mlir::torch::createConvertTorchConversionToMLProgramPass() {
147+
createConvertTorchConversionToMLProgramPass() {
142148
return std::make_unique<ConvertTorchConversionToMLProgram>();
143149
}
150+
151+
} // namespace mlir::torch

lib/Conversion/TorchOnnxToTorch/PassDetail.h

Lines changed: 0 additions & 24 deletions
This file was deleted.

lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
//
88
//===----------------------------------------------------------------------===//
99

10-
#include "./PassDetail.h"
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/Pass/Pass.h"
1112
#include "mlir/Support/LLVM.h"
1213
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
1314
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
@@ -19,6 +20,10 @@ using llvm::dbgs;
1920
using namespace mlir;
2021
using namespace mlir::torch;
2122
using namespace mlir::torch::onnx_c;
23+
namespace mlir::torch::onnx_c {
24+
25+
#define GEN_PASS_DEF_CONVERTTORCHONNXTOTORCH
26+
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc"
2227

2328
#define DEBUG_TYPE "torch-onnx"
2429

@@ -37,7 +42,7 @@ int64_t getDefaultOpsetVersion(Operation *containerOp) {
3742
}
3843

3944
class ConvertTorchOnnxToTorch
40-
: public ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
45+
: public impl::ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
4146
public:
4247
ConvertTorchOnnxToTorch() = default;
4348
void runOnOperation() override {
@@ -82,7 +87,8 @@ class ConvertTorchOnnxToTorch
8287

8388
} // namespace
8489

85-
std::unique_ptr<OperationPass<func::FuncOp>>
86-
mlir::torch::onnx_c::createTorchOnnxToTorchPass() {
90+
std::unique_ptr<OperationPass<func::FuncOp>> createTorchOnnxToTorchPass() {
8791
return std::make_unique<ConvertTorchOnnxToTorch>();
8892
}
93+
94+
} // namespace mlir::torch::onnx_c

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Pass/Pass.h"
13+
#include "torch-mlir/Conversion/Passes.h"
1114

12-
#include "../PassDetail.h"
1315
#include "mlir/Dialect/Arith/IR/Arith.h"
1416
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1517
#include "mlir/Dialect/Math/IR/Math.h"
@@ -25,6 +27,10 @@
2527
using namespace mlir;
2628
using namespace mlir::torch;
2729
using namespace mlir::torch::Torch;
30+
namespace mlir::torch {
31+
32+
#define GEN_PASS_DEF_CONVERTTORCHTOARITH
33+
#include "torch-mlir/Conversion/Passes.h.inc"
2834

2935
// -----------------------------------------------------------------------------
3036
// Patterns (as this grows, it should be organized into multiple files)
@@ -407,7 +413,7 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern<OpTy> {
407413

408414
namespace {
409415
class ConvertTorchToArith
410-
: public ConvertTorchToArithBase<ConvertTorchToArith> {
416+
: public impl::ConvertTorchToArithBase<ConvertTorchToArith> {
411417
public:
412418
void getDependentDialects(DialectRegistry &registry) const override {
413419
registry.insert<func::FuncDialect>();
@@ -565,7 +571,8 @@ class ConvertTorchToArith
565571
};
566572
} // namespace
567573

568-
std::unique_ptr<OperationPass<func::FuncOp>>
569-
mlir::torch::createConvertTorchToArithPass() {
574+
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass() {
570575
return std::make_unique<ConvertTorchToArith>();
571576
}
577+
578+
} // namespace mlir::torch

0 commit comments

Comments
 (0)