Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@

namespace mlir {
namespace torch {

#define GEN_PASS_DECL_CONVERTTORCHTOSTABLEHLO
#include "torch-mlir/Conversion/Passes.h.inc"

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass();

// Convenience wrapper for users who want to pass options as individual
// parameters
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);

} // namespace torch
} // namespace mlir

Expand Down
7 changes: 7 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
namespace mlir {
namespace torch {

#define GEN_PASS_DECL_CONVERTTORCHTOTOSA
#include "torch-mlir/Conversion/Passes.h.inc"

/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
/// dialect.
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
Expand All @@ -30,8 +33,12 @@ populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
RewritePatternSet &patterns);

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();

// Convenience wrapper for users who want to pass options as individual
// parameters
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToTosaPass(bool requireFullTosaConversion);

} // namespace torch
} // namespace mlir

Expand Down
13 changes: 0 additions & 13 deletions include/torch-mlir/RefBackend/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,12 @@
#include "mlir/Pass/PassManager.h"

namespace mlir {
class ModuleOp;

namespace torch {
namespace RefBackend {

/// Registers all RefBackend passes.
void registerRefBackendPasses();

std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();

std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();

std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();

std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();

std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorConcatPass();

std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
} // namespace RefBackend
} // namespace torch
} // namespace mlir
Expand Down
6 changes: 0 additions & 6 deletions include/torch-mlir/RefBackend/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,29 @@ include "mlir/Pass/PassBase.td"

def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleOp"> {
let summary = "Munge calling conventions for calling via ExecutionEngine";
let constructor = "mlir::torch::RefBackend::createMungeCallingConventionsPass();";
let dependentDialects = ["memref::MemRefDialect"];
}

def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> {
let summary = "Bufferize the MLProgram dialect ops";
let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();";
let dependentDialects = ["memref::MemRefDialect"];
}

def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> {
let summary = "Expand ops into more primitive ops before LLVM lowering.";
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
}

def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
let summary = "Munge memref.copy to linalg.copy";
let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
let dependentDialects = ["memref::MemRefDialect"];
}

def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> {
let summary = "Convert tensor.concat to other tensor ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()";
}

def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
let summary = "Convert tensor.pad to linalg ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
}

#endif // TORCHMLIR_REFBACKEND_PASSES
26 changes: 0 additions & 26 deletions lib/Conversion/PassDetail.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand All @@ -20,6 +22,10 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHCONVERSIONTOMLPROGRAM
#include "torch-mlir/Conversion/Passes.h.inc"

static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }

Expand Down Expand Up @@ -102,7 +108,7 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {

namespace {
class ConvertTorchConversionToMLProgram
: public ConvertTorchConversionToMLProgramBase<
: public impl::ConvertTorchConversionToMLProgramBase<
ConvertTorchConversionToMLProgram> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
Expand Down Expand Up @@ -138,6 +144,8 @@ class ConvertTorchConversionToMLProgram
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::createConvertTorchConversionToMLProgramPass() {
createConvertTorchConversionToMLProgramPass() {
return std::make_unique<ConvertTorchConversionToMLProgram>();
}

} // namespace mlir::torch
24 changes: 0 additions & 24 deletions lib/Conversion/TorchOnnxToTorch/PassDetail.h

This file was deleted.

14 changes: 10 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
//
//===----------------------------------------------------------------------===//

#include "./PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
Expand All @@ -19,6 +20,10 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;
namespace mlir::torch::onnx_c {

#define GEN_PASS_DEF_CONVERTTORCHONNXTOTORCH
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc"

#define DEBUG_TYPE "torch-onnx"

Expand All @@ -37,7 +42,7 @@ int64_t getDefaultOpsetVersion(Operation *containerOp) {
}

class ConvertTorchOnnxToTorch
: public ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
: public impl::ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
public:
ConvertTorchOnnxToTorch() = default;
void runOnOperation() override {
Expand Down Expand Up @@ -82,7 +87,8 @@ class ConvertTorchOnnxToTorch

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::onnx_c::createTorchOnnxToTorchPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createTorchOnnxToTorchPass() {
return std::make_unique<ConvertTorchOnnxToTorch>();
}

} // namespace mlir::torch::onnx_c
15 changes: 11 additions & 4 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
Expand All @@ -25,6 +27,10 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHTOARITH
#include "torch-mlir/Conversion/Passes.h.inc"

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
Expand Down Expand Up @@ -407,7 +413,7 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern<OpTy> {

namespace {
class ConvertTorchToArith
: public ConvertTorchToArithBase<ConvertTorchToArith> {
: public impl::ConvertTorchToArithBase<ConvertTorchToArith> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
Expand Down Expand Up @@ -565,7 +571,8 @@ class ConvertTorchToArith
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToArithPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass() {
return std::make_unique<ConvertTorchToArith>();
}

} // namespace mlir::torch
15 changes: 11 additions & 4 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
Expand All @@ -24,6 +26,10 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHTOLINALG
#include "torch-mlir/Conversion/Passes.h.inc"

// -----------------------------------------------------------------------------
// The pass
Expand All @@ -34,7 +40,7 @@ using namespace mlir::torch::Torch;

namespace {
class ConvertTorchToLinalg
: public ConvertTorchToLinalgBase<ConvertTorchToLinalg> {
: public impl::ConvertTorchToLinalgBase<ConvertTorchToLinalg> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
Expand Down Expand Up @@ -89,7 +95,8 @@ class ConvertTorchToLinalg
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToLinalgPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass() {
return std::make_unique<ConvertTorchToLinalg>();
}

} // namespace mlir::torch
1 change: 0 additions & 1 deletion lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
Expand Down
Loading
Loading