[Refactor] Refactor backend API #1869

Open · wants to merge 35 commits into base: dev-1.x
346 changes: 185 additions & 161 deletions docs/en/07-developer-guide/support_new_backend.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/en/experimental/onnx_optimizer.md
@@ -21,9 +21,9 @@ cmake --build . -- -j$(nproc) && cmake --install .

```python
# import model_to_graph_custom_optimizer so we can hijack onnx.export
-from mmdeploy.apis.onnx.optimizer import model_to_graph__custom_optimizer # noqa
+from mmdeploy.ir.onnx.optimizer import model_to_graph__custom_optimizer # noqa
from mmdeploy.core import RewriterContext
-from mmdeploy.apis.onnx.passes import optimize_onnx
+from mmdeploy.ir.onnx.passes import optimize_onnx

# load your model here
model = create_model()
```
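For reference, a runnable sketch of the snippet above under the new `mmdeploy.ir.onnx` layout from this PR; the stand-in model, input shape, and output file are illustrative, not from the original document:

```python
import torch

from mmdeploy.core import RewriterContext
from mmdeploy.ir.onnx.optimizer import model_to_graph__custom_optimizer  # noqa
from mmdeploy.ir.onnx.passes import optimize_onnx

# stand-in model; any torch.nn.Module works here
model = torch.nn.Conv2d(3, 8, 3, padding=1)
x = torch.rand(1, 3, 224, 224)

# the hijacked onnx.export applies the extra optimizer passes
with RewriterContext({}, onnx_custom_passes=optimize_onnx):
    torch.onnx.export(model, x, 'model.onnx', opset_version=11)
```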
347 changes: 185 additions & 162 deletions docs/zh_cn/07-developer-guide/support_new_backend.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/zh_cn/experimental/onnx_optimizer.md
@@ -21,9 +21,9 @@ cmake --build . -- -j$(nproc) && cmake --install .

```python
# import model_to_graph_custom_optimizer so we can hijack onnx.export
-from mmdeploy.apis.onnx.optimizer import model_to_graph__custom_optimizer # noqa
+from mmdeploy.ir.onnx.optimizer import model_to_graph__custom_optimizer # noqa
from mmdeploy.core import RewriterContext
-from mmdeploy.apis.onnx.passes import optimize_onnx
+from mmdeploy.ir.onnx.passes import optimize_onnx

# load your model here
model = create_model()
```
2 changes: 1 addition & 1 deletion docs/zh_cn/tutorial/07_write_a_plugin.md
@@ -490,7 +490,7 @@ engine = from_onnx(
            opt_shape=[1, 1, 512, 512],
            max_shape=[1, 1, 1024, 1024])))

-from mmdeploy.backend.tensorrt import TRTWrapper
+from mmdeploy.backend.tensorrt.wrapper import TRTWrapper

trt_model = TRTWrapper('srcnn3.engine', ['output'])

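A sketch of running the engine through the re-exported wrapper path; the input name and shapes follow the srcnn3 example in this tutorial, the rest is illustrative:

```python
import torch

from mmdeploy.backend.tensorrt.wrapper import TRTWrapper

trt_model = TRTWrapper('srcnn3.engine', ['output'])
# 'input' matches the input name used when the srcnn3 ONNX model was exported
inputs = dict(input=torch.rand(1, 1, 512, 512).cuda())
outputs = trt_model(inputs)
print(outputs['output'].shape)
```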
46 changes: 46 additions & 0 deletions mmdeploy/__main__.py
@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import sys as _sys

if __name__ == '__main__':
    # args default to the system args
    console_args = _sys.argv[1:]

    # extract help
    help = False
    if '-h' in console_args:
        help = True
        console_args.remove('-h')
    if '--help' in console_args:
        help = True
        console_args.remove('--help')

    # add root parser
    parser = argparse.ArgumentParser(
        'mmdeploy', description='MMDeploy Toolkit')
    command_parsers = parser.add_subparsers(title='Commands', dest='command')
    list_parser = command_parsers.add_parser(
        'list', help='List available backend and task.')
    show_parser = command_parsers.add_parser(
        'show', help='Show information about the object.')
    run_parser = command_parsers.add_parser(
        'run', help='Run console tools of backend or task.')
    args, remain_args = parser.parse_known_args(console_args)

    # parse command
    command = getattr(args, 'command', None)

    if help:
        remain_args = ['--help'] + remain_args
    if command == 'list':
        from mmdeploy.tools.console import list_command
        list_command(list_parser, remain_args)
    elif command == 'show':
        from mmdeploy.tools.console import show_command
        show_command(show_parser, remain_args)
    elif command == 'run':
        from mmdeploy.tools.console import run_command
        run_command(run_parser, remain_args)
    else:
        parser.print_help()
        parser.exit()
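Assuming the `mmdeploy.tools.console` helpers land as wired above, the module is meant to be driven with `python -m mmdeploy`; a hypothetical session, inferred from the parser rather than documented in this PR:

```python
# Hypothetical usage of the entry point above; subcommand arguments are
# handled by mmdeploy.tools.console, not by this file.
#
#   $ python -m mmdeploy list                # list available backends/tasks
#   $ python -m mmdeploy show <object>       # show details of one object
#   $ python -m mmdeploy run <backend> ...   # forward args to a console tool
#   $ python -m mmdeploy                     # no command: print help and exit
```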
84 changes: 11 additions & 73 deletions mmdeploy/apis/onnx/export.py
@@ -1,15 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
-from functools import partial
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union

import torch

from mmdeploy.apis.core import PIPELINE_MANAGER
-from mmdeploy.core import RewriterContext, patch_model
-from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger
-from .optimizer import * # noqa
-from .passes import optimize_onnx
+from mmdeploy.ir.onnx import ONNXManager
+from mmdeploy.utils import Backend


@PIPELINE_MANAGER.register_pipeline()
@@ -70,75 +66,17 @@ def export(model: torch.nn.Module,
"""
output_path = output_path_prefix + '.onnx'

logger = get_root_logger()
logger.info(f'Export PyTorch model to ONNX: {output_path}.')

def _add_or_update(cfg: dict, key: str, val: Any):
if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict):
cfg[key].update(val)
else:
cfg[key] = val

context_info = deepcopy(context_info)
deploy_cfg = context_info.pop('deploy_cfg', dict())
ir_config = dict(
type='onnx',
ONNXManager.export(
model,
args,
output_path,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
verbose=verbose,
keep_initializers_as_inputs=keep_initializers_as_inputs)
_add_or_update(deploy_cfg, 'ir_config', ir_config)
ir = IR.get(get_ir_config(deploy_cfg)['type'])
if isinstance(backend, Backend):
backend = backend.value
backend_config = dict(type=backend)
_add_or_update(deploy_cfg, 'backend_config', backend_config)

context_info['cfg'] = deploy_cfg
context_info['ir'] = ir
if 'backend' not in context_info:
context_info['backend'] = backend
if 'opset' not in context_info:
context_info['opset'] = opset_version

# patch model
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend, ir=ir)

if 'onnx_custom_passes' not in context_info:
onnx_custom_passes = optimize_onnx if optimize else None
context_info['onnx_custom_passes'] = onnx_custom_passes
with RewriterContext(**context_info), torch.no_grad():
# patch input_metas
if input_metas is not None:
assert isinstance(
input_metas, dict
), f'Expect input_metas type is dict, get {type(input_metas)}.'
model_forward = patched_model.forward

def wrap_forward(forward):

def wrapper(*arg, **kwargs):
return forward(*arg, **kwargs)

return wrapper

patched_model.forward = wrap_forward(patched_model.forward)
patched_model.forward = partial(patched_model.forward,
**input_metas)

torch.onnx.export(
patched_model,
args,
output_path,
export_params=True,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
verbose=verbose)

if input_metas is not None:
patched_model.forward = model_forward
backend=backend,
const_args=input_metas,
rewrite_context=deploy_cfg,
optimize=optimize)
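A minimal sketch of the new call path, assuming the `ONNXManager.export` signature shown in this diff; the stand-in model and config values are illustrative:

```python
import torch

from mmdeploy.ir.onnx import ONNXManager

# stand-in model; any torch.nn.Module works here
model = torch.nn.Conv2d(3, 8, 3, padding=1)
dummy_input = torch.rand(1, 3, 224, 224)

ONNXManager.export(
    model,
    dummy_input,
    'end2end.onnx',
    input_names=['input'],
    output_names=['output'],
    opset_version=11,
    backend='onnxruntime',
    const_args=None,         # replaces the old `input_metas` argument
    rewrite_context=dict(),  # deploy config consumed by the rewriters
    optimize=True)
```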
2 changes: 1 addition & 1 deletion mmdeploy/apis/openvino/utils.py
@@ -3,7 +3,7 @@

import mmengine

-from mmdeploy.backend.openvino import ModelOptimizerOptions
+from mmdeploy.backend.openvino.utils import ModelOptimizerOptions
from mmdeploy.utils import get_model_inputs
from mmdeploy.utils.config_utils import get_backend_config, get_ir_config

2 changes: 1 addition & 1 deletion mmdeploy/apis/snpe/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from mmdeploy.backend.snpe import from_onnx as _from_onnx
from mmdeploy.backend.snpe import is_available
+from mmdeploy.backend.snpe.onnx2dlc import from_onnx as _from_onnx
from ..core import PIPELINE_MANAGER

from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)
4 changes: 2 additions & 2 deletions mmdeploy/apis/tensorrt/__init__.py
@@ -5,8 +5,8 @@
__all__ = ['is_available']

if is_available():
-    from mmdeploy.backend.tensorrt import from_onnx as _from_onnx
-    from mmdeploy.backend.tensorrt import load, save
+    from mmdeploy.backend.tensorrt.utils import from_onnx as _from_onnx
+    from mmdeploy.backend.tensorrt.utils import load, save
    from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)
    __all__ += ['from_onnx', 'save', 'load']
    try:
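The public import path is unchanged by the re-export. A quick sanity-check sketch, reusing the srcnn3 shapes from the plugin tutorial; the `input_shapes` layout follows the existing `from_onnx` API, and the min shape is an assumed value:

```python
from mmdeploy.apis.tensorrt import from_onnx, is_available

if is_available():
    from_onnx(
        'srcnn3.onnx',
        'srcnn3',  # output prefix; the .engine file is written next to it
        input_shapes=dict(
            input=dict(
                min_shape=[1, 1, 256, 256],  # assumed; not shown in the diff
                opt_shape=[1, 1, 512, 512],
                max_shape=[1, 1, 1024, 1024])))
```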
89 changes: 21 additions & 68 deletions mmdeploy/apis/torch_jit/trace.py
@@ -1,12 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
-from functools import partial
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

import torch

-from mmdeploy.core import RewriterContext, patch_model
-from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger
+from mmdeploy.ir.torchscript import export
+from mmdeploy.utils import Backend
from ..core import PIPELINE_MANAGER


@@ -27,9 +25,10 @@ def trace(func: torch.nn.Module,
        >>> func = create_model()
        >>> inputs = get_input_tensor()
        >>>
-        >>> jit_model = trace(
+        >>> trace(
        >>>     func,
        >>>     inputs,
+        >>>     output_prefix,
        >>>     backend='torchscript',
        >>>     check_trace=False)
        >>>
@@ -55,69 +54,23 @@
    Returns:
        torch.jit.TracedModule: The traced torch jit model.
    """
-    logger = get_root_logger()
-    logger.info('Export PyTorch model to torchscript.')
+    if output_path_prefix is None:
+        from tempfile import NamedTemporaryFile
+        output_path = NamedTemporaryFile(suffix='.pth').name
+    else:
+        output_path = output_path_prefix + '.pth'

-    def _add_or_update(cfg: dict, key: str, val: Any):
-        if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict):
-            cfg[key].update(val)
-        else:
-            cfg[key] = val
-
-    context_info = deepcopy(context_info)
    deploy_cfg = context_info.pop('deploy_cfg', dict())
-    ir_config = dict(type='torchscript')
-    _add_or_update(deploy_cfg, 'ir_config', ir_config)
-
-    if isinstance(backend, Backend):
-        backend = backend.value
-    backend_config = dict(type=backend)
-    _add_or_update(deploy_cfg, 'backend_config', backend_config)
-
-    context_info['cfg'] = deploy_cfg
-    if 'backend' not in context_info:
-        context_info['backend'] = backend
-    elif context_info['backend'] != backend:
-        logger.warning(
-            f'Find backend {context_info["backend"]} in context_info.'
-            f' Expect {backend}.')
-    if 'ir' not in context_info:
-        context_info['ir'] = IR.TORCHSCRIPT
-    elif context_info['ir'] != backend:
-        logger.warning(f'Find ir {context_info["ir"]} in context_info.'
-                       f' Expect {IR.TORCHSCRIPT}.')
-
-    # patch model
-    if isinstance(func, torch.nn.Module):
-        ir = IR.get(get_ir_config(deploy_cfg)['type'])
-        func = patch_model(func, cfg=deploy_cfg, backend=backend, ir=ir)
-
-    with RewriterContext(**context_info), torch.no_grad():
-
-        # patch input_metas
-        if input_metas is not None:
-            assert isinstance(
-                input_metas, dict
-            ), f'Expect input_metas type is dict, get {type(input_metas)}.'
-            model_forward = func.forward
-            func.forward = partial(func.forward, **input_metas)
-
-        # for exporting models with weight that depends on inputs
-        func(*inputs) if isinstance(inputs, Sequence) \
-            else func(inputs)
-        ts_model = torch.jit.trace(
-            func,
-            inputs,
-            check_trace=check_trace,
-            check_tolerance=check_tolerance)
-
-        if input_metas is not None:
-            func.forward = model_forward
-
-    # save model
-    if output_path_prefix is not None:
-        output_path = output_path_prefix + '.pt'
-        logger.info(f'Save PyTorch model: {output_path}.')
-        torch.jit.save(ts_model, output_path)
+    export(
+        func,
+        inputs,
+        output_path,
+        backend=backend,
+        rewrite_context=deploy_cfg,
+        check_trace=check_trace,
+        check_tolerance=check_tolerance,
+        const_args=input_metas)
+
+    ts_model = torch.jit.load(output_path)

    return ts_model
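A sketch mirroring the updated docstring example; the stand-in module and the output prefix are illustrative:

```python
import torch

from mmdeploy.apis.torch_jit import trace

# stand-in module and input; any traceable module works
func = torch.nn.Conv2d(3, 8, 3, padding=1)
inputs = torch.rand(1, 3, 224, 224)

# writes end2end.pth (a temporary file is used when the prefix is None)
jit_model = trace(
    func,
    inputs,
    output_path_prefix='end2end',
    backend='torchscript',
    check_trace=False)
```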
5 changes: 2 additions & 3 deletions mmdeploy/apis/tvm/__init__.py
@@ -5,8 +5,7 @@
__all__ = ['is_available', 'get_library_ext']

if is_available():
-    from mmdeploy.backend.tvm import HDF5Dataset
-    from mmdeploy.backend.tvm import from_onnx as _from_onnx
+    from mmdeploy.backend.tvm.onnx2tvm import from_onnx as _from_onnx
    from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)

-    __all__ += ['from_onnx', 'HDF5Dataset']
+    __all__ += ['from_onnx']
40 changes: 33 additions & 7 deletions mmdeploy/apis/utils/utils.py
@@ -93,12 +93,38 @@ def to_backend(backend_name: str,
    Returns:
        Sequence[str]: Backend files.
    """
+    import os.path as osp
+    from copy import deepcopy
+
    from mmdeploy.backend.base import get_backend_manager
+    from mmdeploy.utils import get_model_inputs
    backend_mgr = get_backend_manager(backend_name)
-    return backend_mgr.to_backend(
-        ir_files=ir_files,
-        work_dir=work_dir,
-        deploy_cfg=deploy_cfg,
-        log_level=log_level,
-        device=device,
-        **kwargs)
+
+    model_inputs = get_model_inputs(deploy_cfg)
+    assert model_inputs is None or len(model_inputs) == 0 or len(
+        model_inputs) == len(ir_files)
+    backend_files = []
+    for idx, ir_file in enumerate(ir_files):
+        if isinstance(model_inputs, (list, tuple)) and len(model_inputs) > 0:
+            curr_deploy_cfg = deepcopy(deploy_cfg)
+            curr_deploy_cfg['backend_config']['model_inputs'] = [
+                model_inputs[idx]
+            ]
+        else:
+            curr_deploy_cfg = deploy_cfg
+
+        file_name = osp.splitext(osp.split(ir_file)[1])[0]
+        param = backend_mgr.build_param_from_config(
+            curr_deploy_cfg,
+            work_dir=work_dir,
+            backend_files=[file_name],
+            device=device,
+            **kwargs)
+
+        backend_mgr.to_backend_from_param(ir_file, param)
+        backend_file = param.get_model_files()
+        if isinstance(backend_file, str):
+            backend_file = [backend_file]
+        backend_files += backend_file
+
+    return backend_files
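The loop above, condensed to a single conversion; this assumes the new backend-manager param API introduced in this PR (`build_param_from_config`, `to_backend_from_param`, `get_model_files`), with illustrative file names and a deliberately minimal config:

```python
from mmdeploy.backend.base import get_backend_manager

# illustrative minimal config; real deploy configs carry much more
deploy_cfg = dict(backend_config=dict(type='tensorrt'))

backend_mgr = get_backend_manager('tensorrt')
param = backend_mgr.build_param_from_config(
    deploy_cfg,
    work_dir='work_dir',
    backend_files=['end2end'],  # file stem; the manager adds the extension
    device='cuda:0')
backend_mgr.to_backend_from_param('work_dir/end2end.onnx', param)
backend_file = param.get_model_files()  # e.g. 'work_dir/end2end.engine'
```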
2 changes: 1 addition & 1 deletion mmdeploy/apis/vacc/__init__.py
@@ -6,7 +6,7 @@

if is_available():
    try:
-        from mmdeploy.backend.vacc import from_onnx as _from_onnx
+        from mmdeploy.backend.vacc.onnx2vacc import from_onnx as _from_onnx
        from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)
        __all__ += ['from_onnx']
    except Exception: