Skip to content

Commit e06efc5

Browse files
Initial TorchOnnxToTorch conversion pipeline. (#2585)
Adds a pipeline to convert custom ops and metadata represented as `torch.operator` custom ops to corresponding `torch` ops where possible. This is part of a multi-part approach for building ONNX import in as a regular feature of torch-mlir. It is focused on the conversions vs the infra. We will end up maintaining a [pure-python importer](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) to go with this in torch-mlir, and we will also maintain test case generation utilities derived from it. I have left substantial documentation in the README of the conversion directory, including the recommended approach that we will take to keep building this out. (note that this organizes the code to coincide with the refactoring in #2442 versus the current flat arrangement)
1 parent d50d3aa commit e06efc5

File tree

19 files changed

+897
-6
lines changed

19 files changed

+897
-6
lines changed

include/torch-mlir/Conversion/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
add_subdirectory(TorchOnnxToTorch)
2+
13
set(LLVM_TARGET_DEFINITIONS Passes.td)
24
if(TORCH_MLIR_ENABLE_STABLEHLO)
35
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls)
3+
add_public_tablegen_target(TorchMLIRConversionTorchOnnxToTorchPassIncGen)
4+
add_mlir_doc(Passes TorchMLIRConversionTorchOnnxToTorchPasses ./ -gen-pass-doc)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===------------------------------------------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H
11+
#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H
12+
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/IR/BuiltinOps.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include <memory>
17+
18+
namespace mlir::torch::onnx_c {
19+
20+
std::unique_ptr<OperationPass<func::FuncOp>> createTorchOnnxToTorchPass();
21+
22+
/// Registers all torch-mlir conversion passes.
23+
void registerTorchOnnxToTorchPasses();
24+
25+
} // namespace mlir::torch::onnx_c
26+
27+
#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES
11+
#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES
12+
13+
include "mlir/Pass/PassBase.td"
14+
15+
def ConvertTorchOnnxToTorch : Pass<"convert-torch-onnx-to-torch", "func::FuncOp"> {
  let summary = "Converts ONNX custom ops in the torch dialect to native torch ops";
  let description = [{
    Converts equivalent ONNX custom ops to built-in equivalents.

    See the README for a detailed description of how this operates.
  }];
  let constructor = "mlir::torch::onnx_c::createTorchOnnxToTorchPass()";
}
25+
26+
#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
//===------------------------------------------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H
11+
#define TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H
12+
13+
#include "mlir/IR/BuiltinAttributes.h"
14+
#include "mlir/Support/LogicalResult.h"
15+
#include "mlir/Transforms/DialectConversion.h"
16+
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
17+
#include "llvm/ADT/DenseMap.h"
18+
#include "llvm/ADT/SmallString.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
21+
namespace mlir::torch::onnx_c {
22+
23+
/// Used during ONNX pattern matching to bind common patterns of operands,
24+
/// result types and attributes to local variables in a way that is easy
25+
/// to fail the pattern if constraints are violated. Most methods return
26+
/// a ParseResult, which allows for chaining like:
27+
///
28+
/// if (binder.tensorOperand(foo) || binder.tensorResultType(t))
29+
/// return failure();
30+
struct OpBinder {
31+
OpBinder(Operation *op) : op(op) {}
32+
33+
Location getLoc() { return op->getLoc(); }
34+
35+
// Operand matches of different arities.
36+
ParseResult tensorOperand(Value &value0) {
37+
if (op->getNumOperands() != 1)
38+
return failure();
39+
value0 = op->getOperand(0);
40+
if (!toValidTensorType(value0.getType()))
41+
return failure();
42+
return success();
43+
}
44+
45+
ParseResult tensorOperands(Value &value0, Value &value1) {
46+
if (op->getNumOperands() != 2)
47+
return failure();
48+
value0 = op->getOperand(0);
49+
value1 = op->getOperand(1);
50+
if (!toValidTensorType(value0.getType()) ||
51+
!toValidTensorType(value1.getType()))
52+
return failure();
53+
return success();
54+
}
55+
56+
// Result type matchers of different arities.
57+
ParseResult tensorResultType(Torch::ValueTensorType &type0) {
58+
if (op->getNumResults() != 1)
59+
return failure();
60+
auto t = toValidTensorType(op->getResult(0).getType());
61+
if (!t)
62+
return failure();
63+
type0 = t;
64+
return success();
65+
}
66+
67+
// Attribute accessors.
68+
ParseResult s64BoolAttr(bool &value, StringRef nameSuffix,
69+
bool defaultValue = false) {
70+
SmallString<64> name("torch.onnx.");
71+
name.append(nameSuffix);
72+
auto attr = op->getAttr(name);
73+
if (!attr) {
74+
value = defaultValue;
75+
return success();
76+
}
77+
if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
78+
IntegerType t = cast<IntegerType>(integerAttr.getType());
79+
if (!t.isSigned() || t.getWidth() != 64)
80+
return failure();
81+
value = static_cast<bool>(integerAttr.getSInt());
82+
return success();
83+
}
84+
return failure();
85+
}
86+
87+
ParseResult s64IntegerAttr(int64_t &value, StringRef nameSuffix,
88+
int64_t defaultValue = 0) {
89+
SmallString<64> name("torch.onnx.");
90+
name.append(nameSuffix);
91+
auto attr = op->getAttr(name);
92+
if (!attr) {
93+
value = defaultValue;
94+
return success();
95+
}
96+
if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
97+
IntegerType t = cast<IntegerType>(integerAttr.getType());
98+
if (!t.isSigned() || t.getWidth() != 64)
99+
return failure();
100+
value = integerAttr.getSInt();
101+
return success();
102+
}
103+
return failure();
104+
}
105+
106+
Torch::ValueTensorType toValidTensorType(Type t) {
107+
auto tt = dyn_cast<Torch::ValueTensorType>(t);
108+
if (tt && tt.hasSizes())
109+
return tt;
110+
return {};
111+
}
112+
113+
Operation *op;
114+
};
115+
116+
/// We use a single pattern per ONNX domain to handle all named custom
117+
/// ops.
118+
/// This allows us to avoid the n^2 problem on pattern application by
119+
/// implementing a secondary index based on the name and sinceVersion
120+
/// attributes.
121+
/// It also lets us add some ergonomics for trivial cases.
122+
class OnnxCustomOpConversionPattern
123+
: public OpConversionPattern<Torch::OperatorOp> {
124+
public:
125+
using HandlerFn = LogicalResult (*)(OpBinder binder,
126+
ConversionPatternRewriter &rewriter);
127+
struct HandlerReg {
128+
HandlerReg(HandlerFn callback, int64_t sinceVersion)
129+
: callback(callback), sinceVersion(sinceVersion) {}
130+
HandlerFn callback;
131+
int64_t sinceVersion;
132+
};
133+
134+
OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix,
135+
int64_t domainVersion)
136+
: OpConversionPattern(context), domainPrefix(std::move(domainPrefix)),
137+
domainVersion(domainVersion) {}
138+
139+
LogicalResult
140+
matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor,
141+
ConversionPatternRewriter &rewriter) const override;
142+
143+
/// Adds all fully qualified operator names to the given set.
144+
/// This is typically used for implementing a dynamic legality
145+
/// check for torch.operator names.
146+
void populateLegalizedNames(DenseSet<StringAttr> &legalizedNames);
147+
148+
/// Register a conversion for a specific ONNX operator. For the
149+
/// default domain, this is the canonical ONNX operator name (i.e.
150+
/// "Acos").
151+
/// Multiple conversions can be registered for the same op, most
152+
/// commonly differing by their `sinceVersion`.
153+
void onOp(StringRef name, int64_t sinceVersion, HandlerFn callback);
154+
155+
private:
156+
std::string domainPrefix;
157+
int64_t domainVersion;
158+
DenseMap<StringAttr, SmallVector<HandlerReg, 1>> namedHandlers;
159+
};
160+
161+
// Patterns are split into chunks to speed compile time and reduce some
162+
// contention on the same source files.
163+
void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns);
164+
void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns);
165+
void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns);
166+
167+
} // namespace mlir::torch::onnx_c
168+
169+
#endif // TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# TorchOnnx To Torch Conversions
2+
3+
We enable the direct representation of many ONNX features directly in
4+
the `torch` dialect as `torch.operator` custom ops with names like
5+
`onnx.{OperatorName}`. The majority of ONNX operators are represented
6+
with a systematic transformation. See
7+
[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py)
8+
for the reference importer which complies with the rules below
9+
(this is planned to be upstreamed to torch-mlir proper in the near
10+
future).
11+
12+
## Adding new ONNX operators
13+
14+
With the exception of certain special or complicated ONNX operators, most
15+
are relatively straightforward to map, following this general procedure:
17+
* Plan the ops you wish to support by consulting the
18+
[ONNX operator database](https://onnx.ai/onnx/operators/).
19+
* This database has detailed diffs wrt different support versions but
20+
at the level of detail we operate, most version diffs are inconsequential
21+
and just require a bit more pattern support.
22+
* This typically applies to generalization of broadcasting semantics,
23+
expanded type support, and other things of the like.
24+
* *Prerequisite*: Add support for the op to torch-mlir if it does not
25+
already exist.
26+
* Open the corresponding implementation file `DefaultDomainXtoY.cpp`
27+
corresponding with the alphabetic sort of the op and add a conversion.
28+
* Generate successful test cases:
29+
* Either run the Turbine importer to produce MLIR output for all
30+
ops/models in the ONNX test suite or use a dump that someone has
31+
generated:
32+
* [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing)
33+
* There are often many variants of tests for checking conformance of
34+
different historic ONNX encodings, but these are often not load bearing
35+
at the MLIR level.
36+
* Pick a handful of test cases and add them to
37+
`test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an
38+
alphabetic breakdown. At this time, ignore tests that are not exercising
39+
useful differences in the pattern implementations.
40+
* Generate failure test cases:
41+
* Some ops have forms that do not (easily) map to torch-mlir. If you leave
42+
an op under-implemented, add a failing test case to
43+
`test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir`.
44+
* Optional but recommended: Use your test case files to fuzz against the
45+
torch-mlir backend of your choice by running a backend conversion pipeline
46+
and fixing any crashes/issues.
47+
* Send a patch with your changes.
48+
49+
## ONNX proto to `torch` dialect mapping
50+
51+
### Type Conversion
52+
53+
* Tensors: ONNX tensor types are converted to `torch.vtensor`
54+
with static and dynamic dimensions. We require that shape
55+
inference has run to produce ranked tensors.
56+
* Tensor element types are directly converted to corresponding
57+
MLIR types as used by the rest of torch-mlir.
58+
* String, sequence and sparse tensor types are presently not mapped.
59+
60+
### Attributes
61+
62+
A subset of attributes types are converted directly to an attribute
63+
dict on the op with a name like `torch.onnx.{AttributeName}`. The
64+
following attribute type mappings are made:
65+
66+
* `FLOAT`: `FloatAttr`
67+
* `INT`: Signed `IntegerAttr` of width 64
68+
* `STRING`: `StringAttr`
69+
* `TENSOR`: Converted to one of:
70+
* `DenseResourceElementsAttr` for inlined `raw_data`
71+
* `DenseElementsAttr` for splats
72+
* `DenseElementsAttr` for inlined typed proto initialization
73+
* `FLOATS`: `ArrayAttr` of `FloatAttr`
74+
* `INTS`: `ArrayAttr` of signed `IntegerAttr` of width 64
75+
* `STRINGS`: `ArrayAttr` of `StringAttr`
76+
* `TENSORS`: `ArrayAttr` of corresponding `TENSOR` conversion
77+
78+
The following attribute types have no present, systematic conversion.
79+
Their presence on an op indicates that the op is a special form, which
80+
must be handled specially:
81+
82+
* `GRAPH`
83+
* `SPARSE_TENSOR` (TBD: it is possible to handle this systematically if
84+
useful).
85+
* `TYPE_PROTO` (TBD: it may be possible to handle this systematically if
86+
useful).
87+
* Plural equivalents of the above.
88+
89+
### Default operation conversion
90+
91+
Operations are converted to a `torch.operator` with name `onnx.{OperatorName}`.
92+
The constraint that the ONNX graph is topologically sorted and free of
93+
cycles matches the SSA form. Operands and results are mapped directly.
94+
95+
This conversion only applies to the default (empty) domain.
96+
97+
### Quantization information
98+
99+
Quantization parameters are carried out of line in the ONNX protobuf
100+
and will be repatriated upon import to torch. The exact mechanism is
101+
not yet implemented.
102+
103+
### Version and metadata
104+
105+
The `IsolatedFromAbove` parent of the ops can contain the following
106+
metadata:
107+
108+
* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to
109+
`ModelProto.ir_version`.
110+
* `torch.onnx_meta.producer_name`: `StringAttr` corresponding to
111+
`ModelProto.producer_name`.
112+
* `torch.onnx_meta.producer_version`: `StringAttr` corresponding to
113+
`ModelProto.producer_version`.
114+
* `torch.onnx_meta.opset_version`: 64bit `IntegerAttr` corresponding
115+
to `ModelProto.opset_import.version` for the domain "" (empty).
116+
Will be omitted if the default opset is not included.
117+
* `torch.onnx_meta.opset_versions`: DictAttr of 64bit `IntegerAttr`
118+
for each non-default domain.
119+
120+
Generally, the importer handles variations in `ir_version` whereas
121+
the transformations here handle opset version differences. Version
122+
independent transformations are encouraged where possible if there
123+
are only minor variations of an op. Major variations should use
124+
`since_version` sensitive patterns.
125+
126+
### Special op forms
127+
128+
Certain ONNX operators map to different structural components of
129+
torch-mlir's representation:
130+
131+
* `ConstantOfShape`: Mapped to `torch.vtensor.literal` with
132+
a corresponding `value` attribute.
133+

lib/CMakeLists.txt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,19 @@ set(LinkedLibs
1414
MLIRTosaDialect
1515
MLIRSupport
1616

17-
TorchMLIRTorchPasses
18-
TorchMLIRTorchConversionDialect
19-
17+
# Dialects.
18+
TorchMLIRTMTensorDialect
2019
TorchMLIRTorchDialect
21-
TorchMLIRTorchConversionPasses
20+
TorchMLIRTorchConversionDialect
2221

22+
# Dialect passes.
2323
TorchMLIRTMTensorPasses
24-
TorchMLIRTMTensorDialect
24+
TorchMLIRTorchConversionPasses
25+
TorchMLIRTorchPasses
2526

27+
# Conversion passes.
2628
TorchMLIRConversionPasses
29+
TorchMLIRTorchOnnxToTorch
2730
)
2831

2932
if(TORCH_MLIR_ENABLE_REFBACKEND)

lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_subdirectory(TorchOnnxToTorch)
12
add_subdirectory(TorchToLinalg)
23
add_subdirectory(TorchToSCF)
34
add_subdirectory(TorchToArith)

0 commit comments

Comments
 (0)