Skip to content

Commit

Permalink
[onnx] Add torch-mlir-import-onnx tool. (#2637)
Browse files Browse the repository at this point in the history
Simple Python console script to import an ONNX protobuf to the torch
dialect for additional processing.

For installed wheels, this can be used with something like:

```
torch-mlir-import-onnx test/python/onnx_importer/LeakyReLU.onnx
```

Or from a dev setup:

```
python -m torch_mlir.tools.import_onnx ...
```
  • Loading branch information
stellaraccident committed Dec 13, 2023
1 parent 7cf52ae commit ed4df38
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 2 deletions.
7 changes: 7 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
extras/onnx_importer.py
)

declare_mlir_python_sources(TorchMLIRPythonSources.Tools
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
tools/import_onnx/__main__.py
)

################################################################################
# Extensions
################################################################################
Expand Down
77 changes: 77 additions & 0 deletions python/torch_mlir/tools/import_onnx/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

"""Console tool for converting an ONNX proto to torch IR.
Typically, when installed from a wheel, this can be invoked as:
torch-mlir-import-onnx some.pb
Or from Python:
python -m torch_mlir.tools.import_onnx ...
"""
import argparse
from pathlib import Path
import sys

import onnx

from ...extras import onnx_importer

from ...dialects import torch as torch_d
from ...ir import (
Context,
)


def main(args):
model_proto = load_onnx_model(args.input_file)
context = Context()
torch_d.register_dialect(context)
model_info = onnx_importer.ModelInfo(model_proto)
m = model_info.create_module(context=context)
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
imp.import_all()
if not args.no_verify:
m.verify()

# TODO: This isn't very efficient output. If these files ever
# get large, enable bytecode and direct binary emission to save
# some copies.
if args.output_file and args.output_file != "-":
with open(args.output_file, "wt") as f:
print(m.get_asm(assume_verified=not args.no_verify), file=f)
else:
print(m.get_asm(assume_verified=not args.no_verify))


def load_onnx_model(file_path: Path) -> onnx.ModelProto:
raw_model = onnx.load(file_path)
inferred_model = onnx.shape_inference.infer_shapes(raw_model)
return inferred_model


def parse_arguments(argv=None):
parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool")
parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
parser.add_argument(
"-o", dest="output_file", help="Output path (or '-' for stdout)"
)
parser.add_argument(
"--no-verify",
action="store_true",
help="Disable verification prior to printing",
)
args = parser.parse_args(argv)
return args


def _cli_main():
sys.exit(main(parse_arguments()))


if __name__ == "__main__":
_cli_main()
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def build_extension(self, ext):
"onnx": [
"onnx>=1.15.0",
],
}
},
entry_points={
"console_scripts": [
"torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main",
],
},
zip_safe=False,
)
2 changes: 1 addition & 1 deletion test/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir', '.py']
config.suffixes = ['.mlir', '.py', '.runlit']

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
Expand Down
15 changes: 15 additions & 0 deletions test/python/onnx_importer/LeakyReLU.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pytorch0.3:h
"
01" LeakyRelu*
alpha
�#<�torch-jit-exportZ
0



b
1



B
3 changes: 3 additions & 0 deletions test/python/onnx_importer/import_onnx_tool.runlit
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s

# CHECK: torch.operator "onnx.LeakyRelu"

0 comments on commit ed4df38

Please sign in to comment.