From ed4df38e8d86083c4dcc1b58f7d59a4c8cf6ab85 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Tue, 12 Dec 2023 22:01:30 -0800
Subject: [PATCH] [onnx] Add torch-mlir-import-onnx tool. (#2637)

Simple Python console script to import an ONNX protobuf to the torch
dialect for additional processing.

For installed wheels, this can be used with something like:

```
torch-mlir-import-onnx test/python/onnx_importer/LeakyReLU.onnx
```

Or from a dev setup:

```
python -m torch_mlir.tools.import_onnx ...
```
---
 python/CMakeLists.txt                         |  7 ++
 .../torch_mlir/tools/import_onnx/__main__.py  | 77 +++++++++++++++++++
 setup.py                                      |  7 +-
 test/lit.cfg.py                               |  2 +-
 test/python/onnx_importer/LeakyReLU.onnx      | 15 ++++
 .../onnx_importer/import_onnx_tool.runlit     |  3 +
 6 files changed, 109 insertions(+), 2 deletions(-)
 create mode 100644 python/torch_mlir/tools/import_onnx/__main__.py
 create mode 100644 test/python/onnx_importer/LeakyReLU.onnx
 create mode 100644 test/python/onnx_importer/import_onnx_tool.runlit

diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index 7b9bf12f2b8f..f29429b7246c 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -38,6 +38,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
     extras/onnx_importer.py
 )
 
+declare_mlir_python_sources(TorchMLIRPythonSources.Tools
+  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
+  ADD_TO_PARENT TorchMLIRPythonSources
+  SOURCES
+    tools/import_onnx/__main__.py
+)
+
 ################################################################################
 # Extensions
 ################################################################################
diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py
new file mode 100644
index 000000000000..b300b4100b3e
--- /dev/null
+++ b/python/torch_mlir/tools/import_onnx/__main__.py
@@ -0,0 +1,77 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+"""Console tool for converting an ONNX proto to torch IR.
+
+Typically, when installed from a wheel, this can be invoked as:
+
+    torch-mlir-import-onnx some.pb
+
+Or from Python:
+
+    python -m torch_mlir.tools.import_onnx ...
+"""
+import argparse
+from pathlib import Path
+import sys
+
+import onnx
+
+from ...extras import onnx_importer
+
+from ...dialects import torch as torch_d
+from ...ir import (
+    Context,
+)
+
+
+def main(args):
+    model_proto = load_onnx_model(args.input_file)
+    context = Context()
+    torch_d.register_dialect(context)
+    model_info = onnx_importer.ModelInfo(model_proto)
+    m = model_info.create_module(context=context)
+    imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
+    imp.import_all()
+    if not args.no_verify:
+        m.verify()
+
+    # TODO: This isn't very efficient output. If these files ever
+    # get large, enable bytecode and direct binary emission to save
+    # some copies.
+    if args.output_file and args.output_file != "-":
+        with open(args.output_file, "wt") as f:
+            print(m.get_asm(assume_verified=not args.no_verify), file=f)
+    else:
+        print(m.get_asm(assume_verified=not args.no_verify))
+
+
+def load_onnx_model(file_path: Path) -> onnx.ModelProto:
+    raw_model = onnx.load(file_path)
+    inferred_model = onnx.shape_inference.infer_shapes(raw_model)
+    return inferred_model
+
+
+def parse_arguments(argv=None):
+    parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool")
+    parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
+    parser.add_argument(
+        "-o", dest="output_file", help="Output path (or '-' for stdout)"
+    )
+    parser.add_argument(
+        "--no-verify",
+        action="store_true",
+        help="Disable verification prior to printing",
+    )
+    args = parser.parse_args(argv)
+    return args
+
+
+def _cli_main():
+    sys.exit(main(parse_arguments()))
+
+
+if __name__ == "__main__":
+    _cli_main()
diff --git a/setup.py b/setup.py
index a4b42309d755..77c8b2ad047d 100644
--- a/setup.py
+++ b/setup.py
@@ -186,6 +186,11 @@ def build_extension(self, ext):
         "onnx": [
             "onnx>=1.15.0",
         ],
-    }
+    },
+    entry_points={
+        "console_scripts": [
+            "torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main",
+        ],
+    },
     zip_safe=False,
 )
diff --git a/test/lit.cfg.py b/test/lit.cfg.py
index a9753bf22719..4608dfb6c892 100644
--- a/test/lit.cfg.py
+++ b/test/lit.cfg.py
@@ -24,7 +24,7 @@
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.mlir', '.py']
+config.suffixes = ['.mlir', '.py', '.runlit']
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
diff --git a/test/python/onnx_importer/LeakyReLU.onnx b/test/python/onnx_importer/LeakyReLU.onnx
new file mode 100644
index 000000000000..f76bccbce92a
--- /dev/null
+++ b/test/python/onnx_importer/LeakyReLU.onnx
@@ -0,0 +1,15 @@
+pytorch0.3:h
+"
+01" LeakyRelu*
+alpha
+×#< torch-jit-exportZ
+0
+
+
+
+b
+1
+
+
+
+B
\ No newline at end of file
diff --git a/test/python/onnx_importer/import_onnx_tool.runlit b/test/python/onnx_importer/import_onnx_tool.runlit
new file mode 100644
index 000000000000..45b733f9da7a
--- /dev/null
+++ b/test/python/onnx_importer/import_onnx_tool.runlit
@@ -0,0 +1,3 @@
+# RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s
+
+# CHECK: torch.operator "onnx.LeakyRelu"
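For reference, a minimal sketch of driving the same import from Python code rather than the console script. This is only a sketch under assumptions: it mirrors the calls made in `__main__.py` above, assumes torch-mlir is installed with the `onnx` extra, and uses a hypothetical input path `some.onnx`.

```
# Sketch only: programmatic use mirroring __main__.py above. Assumes the
# torch-mlir wheel with the onnx extra is installed; "some.onnx" is a
# hypothetical input path.
import onnx

from torch_mlir.extras import onnx_importer
from torch_mlir.dialects import torch as torch_d
from torch_mlir.ir import Context

# Load the protobuf and run ONNX shape inference, as load_onnx_model() does.
raw_model = onnx.load("some.onnx")
model_proto = onnx.shape_inference.infer_shapes(raw_model)

# Create a context with the torch dialect registered and import the main graph.
context = Context()
torch_d.register_dialect(context)
model_info = onnx_importer.ModelInfo(model_proto)
module = model_info.create_module(context=context)
importer = onnx_importer.NodeImporter.define_function(model_info.main_graph, module)
importer.import_all()

# Verify and print the resulting torch-dialect IR, matching the tool's stdout path.
module.verify()
print(module.get_asm(assume_verified=True))
```

As in `load_onnx_model()`, shape inference runs on the raw proto before import; skipping the verify step corresponds to the tool's `--no-verify` flag.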