Merged
Commits — 134 total (changes shown from 115)
054a2b8
Added TRTWrapper
borisfom Aug 5, 2024
3ab9c83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
4ec1d3b
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 5, 2024
fe71030
Addressing code review comments, adding docstrings, cleanup
borisfom Aug 5, 2024
6a9727f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
29d9725
Added TRT 10.3RC to Dockerfile
borisfom Aug 6, 2024
5b8b4f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
f31d6dd
Workaround for format check
borisfom Aug 6, 2024
9303c32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
c1d0b19
More format check workarounds
borisfom Aug 6, 2024
63c4b70
More format check workarounds
borisfom Aug 6, 2024
9a3d6a6
More format check workarounds
borisfom Aug 6, 2024
8bf0300
Using optional exports for trt_utils
borisfom Aug 6, 2024
c03e49b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
39c94c2
Fixing lint errors
borisfom Aug 6, 2024
35dffcc
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 6, 2024
9d867a7
Format fixed
borisfom Aug 6, 2024
6e2733a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
848a42d
Fixing flake errors
borisfom Aug 6, 2024
9ade6af
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 6, 2024
cf2c3b1
Fixing CI
borisfom Aug 6, 2024
e8b51f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
ddb5bc8
Fixed mypy, Engine refactor
borisfom Aug 7, 2024
79014d7
Merge branch 'dev' into trt-wrappers
yiheng-wang-nv Aug 7, 2024
511081f
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 7, 2024
b188237
Merged cast_utils, copyrights fixed.
borisfom Aug 8, 2024
60cdd74
Added unit test
borisfom Aug 8, 2024
778a44a
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
0ab5d26
TRTWrapper moved to networks
borisfom Aug 9, 2024
a948bfb
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
3a72c76
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 9, 2024
7d449f5
Refactored TRTWrapper args
borisfom Aug 10, 2024
6846fd4
Added docstring for precision
borisfom Aug 10, 2024
d598590
Fixed comments, reordered args
borisfom Aug 11, 2024
9109d3f
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 12, 2024
517c111
Reduced test assert accuracy
borisfom Aug 12, 2024
4739756
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 14, 2024
ed0d93d
Addressing code review comments
borisfom Aug 14, 2024
2ec8e53
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 15, 2024
fdcf118
Added Torch-TRT option, cleaned up engine save method
borisfom Aug 15, 2024
1009dc5
Added trt_wrap adapter
borisfom Aug 16, 2024
763f769
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 16, 2024
fd679c0
Refined trt_wrap
borisfom Aug 16, 2024
dc13b52
Used tempdir for ONNX
borisfom Aug 17, 2024
779de92
Refactored trt wrapper, added trt handler
borisfom Aug 18, 2024
6504dc9
Adjusted refactor for use in config
borisfom Aug 18, 2024
c1be72c
Added fold constant threshold param
borisfom Aug 20, 2024
0f16b8b
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 20, 2024
5c495b6
Logger refactoring
borisfom Aug 20, 2024
5d1ebc2
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 20, 2024
48b85ce
Addressing code review comments
borisfom Aug 22, 2024
1244c49
Added multiple submodules option to trt_wrap
borisfom Aug 22, 2024
a603f13
Added polygraphy to more places, torch-tensorrt option debugging
borisfom Aug 23, 2024
f5be0cc
Renamed trt_wrap -> trt_compile
borisfom Aug 23, 2024
b96ebb4
Reformatted for CI
borisfom Aug 23, 2024
73be701
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 23, 2024
85140e2
Fixed alias issue
borisfom Aug 23, 2024
fa4c182
Fixed base in Dockerfile
borisfom Aug 23, 2024
78a3ef3
Fixed CI test failures
borisfom Aug 23, 2024
267c125
Addressed code review comments
borisfom Aug 23, 2024
9adc035
Added dictionary return option
borisfom Aug 26, 2024
a017fcd
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 26, 2024
7f1c0c1
Fixed return_dict issue
borisfom Aug 26, 2024
a242a64
Implemented https://github.com/Project-MONAI/MONAI/issues/8044
borisfom Aug 26, 2024
5afc912
Generalizing merge logic, adding test case and doc
borisfom Aug 27, 2024
ceff018
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
e294968
Addressing code review comments
borisfom Aug 27, 2024
55cf7fa
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 27, 2024
b6d9179
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
652448a
doc build fixed
borisfom Aug 28, 2024
b793eb2
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 28, 2024
5c4f63a
Fixed formatting
borisfom Aug 28, 2024
dd91183
Fixed formatting
borisfom Aug 28, 2024
c41cb5a
Updated base container to 24.08
borisfom Aug 28, 2024
7e440fc
Renaming trt_wrapper -> trt_compiler, adding TRT handler test
borisfom Aug 28, 2024
329d024
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2024
84de860
fixing CI error
borisfom Aug 28, 2024
875d1a8
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 28, 2024
b84cec4
Fixing min test error, addressing comments
borisfom Aug 28, 2024
9481d9f
optional propagation of dynamo arg fixed, onnx_graphsurgeon package a…
borisfom Aug 28, 2024
6a11581
add vista test cases
yiheng-wang-nv Aug 28, 2024
6e8bd6b
Merge branch 'dev' into trt-wrappers
yiheng-wang-nv Aug 28, 2024
3d221cb
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 28, 2024
792a721
Code review input addressed
borisfom Aug 29, 2024
73ac717
Fixed torch-tensorrt path of trt_compile, added test
borisfom Aug 29, 2024
ea879f2
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Aug 29, 2024
1e7e76d
Fixing tests
borisfom Aug 29, 2024
47e676e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2024
1cb49c3
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Aug 29, 2024
dd4d2d6
Merge branch 'dev' into trt-wrappers
binliunls Aug 29, 2024
6b47a8b
Merge branch 'dev' into trt-wrappers
KumoLiu Aug 31, 2024
80d3928
Merge branch 'trt-wrappers' of github.com:borisfom/MONAI into trt-wra…
borisfom Sep 3, 2024
47823d1
Merge remote-tracking branch 'origin/dev' into trt-wrappers
borisfom Sep 3, 2024
c126e67
Fixing TRT 8.x compatibility
borisfom Sep 3, 2024
5645157
Improved diagnostic, skip trt test if < 10.3
borisfom Sep 3, 2024
9e98f66
Merge branch 'dev' into trt_compiler_fixes
KumoLiu Sep 4, 2024
72a4c3d
trt_compile post-fixes
borisfom Oct 3, 2024
b5f8ff2
Merge remote-tracking branch 'origin/dev' into trt_compiler_fixes
borisfom Oct 3, 2024
bf61b48
exporting controlnet
borisfom Oct 9, 2024
3d16f86
Merge branch 'trt_compiler_fixes' of github.com:borisfom/MONAI into t…
borisfom Oct 9, 2024
6297b45
Working controlnet TRT
borisfom Oct 10, 2024
f00fea4
Reformat
borisfom Oct 10, 2024
57fbcf0
Working TRT for MAISI
borisfom Oct 12, 2024
a004bc5
Working dynamic batch with sequences
borisfom Oct 16, 2024
cee7299
Merge remote-tracking branch 'origin/dev' into maisi-trt
borisfom Oct 16, 2024
adf9bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2024
4002d9d
Merge fixed and style
borisfom Oct 16, 2024
d8407c9
Merge branch 'maisi-trt' of github.com:borisfom/MONAI into maisi-trt
borisfom Oct 16, 2024
a37fd53
Merge remote-tracking branch 'origin/dev' into maisi-trt
borisfom Oct 19, 2024
43ea6a0
Added output_lists option
borisfom Oct 21, 2024
c1791f6
Bugfix for multiple initialization
borisfom Oct 30, 2024
ce73be5
Merge remote-tracking branch 'origin/dev' into maisi-trt
borisfom Oct 30, 2024
14912d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
30b8bcf
Adding Torch patch
borisfom Oct 31, 2024
7adf804
Merge branch 'maisi-trt' of github.com:borisfom/MONAI into maisi-trt
borisfom Oct 31, 2024
8baaa74
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
214def9
Fixing torch_trt compile and test case
borisfom Nov 1, 2024
2452c97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
d07a57f
Added rename table for TRT engine, test for output lists
borisfom Nov 2, 2024
c6e11bb
Merge branch 'maisi-trt' of github.com:borisfom/MONAI into maisi-trt
borisfom Nov 2, 2024
6eeed4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
4ae6372
flake
borisfom Nov 7, 2024
c7a1fb7
Merge remote-tracking branch 'origin/dev' into maisi-trt
borisfom Nov 8, 2024
a1d243f
Fixed mypy
borisfom Nov 8, 2024
7b10662
style
borisfom Nov 8, 2024
d367a01
mypy
borisfom Nov 8, 2024
959e237
fixing unroll
borisfom Nov 13, 2024
21a95bd
Merge remote-tracking branch 'origin/dev' into maisi-trt
borisfom Nov 13, 2024
723ab69
Merge branch 'dev' into maisi-trt
KumoLiu Nov 13, 2024
9c78ee3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
03cb981
fix format
KumoLiu Nov 13, 2024
dd12d09
Merge branch 'dev' into maisi-trt
KumoLiu Nov 13, 2024
5b51f50
Merge remote-tracking branch 'boris/maisi-trt' into maisi-trt
KumoLiu Nov 14, 2024
5f34925
fix ci
KumoLiu Nov 14, 2024
Dockerfile (8 changes: 7 additions & 1 deletion)
@@ -11,7 +11,7 @@

# To build with a different base image
# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3
ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.10-py3
FROM ${PYTORCH_IMAGE}

LABEL maintainer="[email protected]"
@@ -41,6 +41,10 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \
COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./
COPY tests ./tests
COPY monai ./monai

# TODO: remove this line and torch.patch for 24.11
RUN patch -R -d /usr/local/lib/python3.10/dist-packages/torch/onnx/ < ./monai/torch.patch

RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \
&& rm -rf build __pycache__

@@ -57,4 +61,6 @@ RUN apt-get update \
# append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
ENV PATH=${PATH}:/opt/tools
ENV POLYGRAPHY_AUTOINSTALL_DEPS=1


WORKDIR /opt/monai
monai/networks/nets/vista3d.py (1 change: 0 additions & 1 deletion)
@@ -641,7 +641,6 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
# [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension.
masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)
masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)

return masks_embedding, class_embedding


monai/networks/trt_compiler.py (192 changes: 145 additions & 47 deletions)
@@ -18,7 +18,7 @@
from collections import OrderedDict
from pathlib import Path
from types import MethodType
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Tuple, Union

import torch

@@ -134,6 +134,9 @@ def __init__(self, plan_path, logger=None):
self.output_names.append(binding)
dtype = dtype_dict[self.engine.get_tensor_dtype(binding)]
self.dtypes.append(dtype)
self.logger.info(
f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}"
)

def allocate_buffers(self, device):
"""
@@ -163,7 +166,8 @@ def set_inputs(self, feed_dict, stream):
last_profile = self.cur_profile

def try_set_inputs():
for binding, t in feed_dict.items():
for binding in self.input_names:
t = feed_dict[binding]
if t is not None:
t = t.contiguous()
shape = t.shape
@@ -180,7 +184,8 @@ def try_set_inputs():
raise
self.cur_profile = next_profile
ctx.set_optimization_profile_async(self.cur_profile, stream)

except Exception:
raise
left = ctx.infer_shapes()
assert len(left) == 0

@@ -217,6 +222,72 @@ def infer(self, stream, use_cuda_graph=False):
return self.tensors


def unroll_input(input_names, input_example):
# Simulate list/tuple unrolling during ONNX export
unrolled_input = {}
for name in input_names:
val = input_example[name]
if val is not None:
if isinstance(val, list | tuple):
for i in range(len(val)):
unrolled_input[f"{name}_{i}"] = val[i]
else:
unrolled_input[name] = val
return unrolled_input
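
For illustration, a minimal sketch of what this unrolling yields (names and shapes below are hypothetical, not taken from the PR's tests):

# "x" is a plain tensor; "cond" is a list of two tensors.
example = {"x": torch.randn(2, 3), "cond": [torch.randn(2), torch.randn(2)]}
unroll_input(["x", "cond"], example)
# -> {"x": <tensor>, "cond_0": <tensor>, "cond_1": <tensor>}
# List/tuple entries are flattened to indexed names, mirroring how the
# ONNX export unrolls sequence inputs.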


def parse_groups(
ret: List[torch.Tensor], output_lists: List[int]
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]:
"""
Implements parsing of 'output_lists' arg of trt_compile().

Args:
ret: plain list of Tensors

output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list
of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
Format: [[group_n] | [], ...]
[] or group_n == 0 : next output from ret is a scalar
group_n > 0 : next output from ret is a list of group_n length
group_n == -1: next output is a dynamic list. This entry can be at any
position in output_lists, but can appear only once.
Returns:
Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists

"""
groups = []
cur = 0
for l in range(len(output_lists)):
gl = output_lists[l]
assert len(gl) == 0 or len(gl) == 1
if len(gl) == 0 or gl[0] == 0:
groups.append(ret[cur])
cur = cur + 1
elif gl[0] > 0:
groups.append(ret[cur : cur + gl[0]])
cur = cur + gl[0]
elif gl[0] == -1:
rev_groups = []
rcur = len(ret)
for rl in range(len(output_lists) - 1, l, -1):
rgl = output_lists[rl]
assert len(rgl) == 0 or len(rgl) == 1
if len(rgl) == 0 or rgl[0] == 0:
rcur = rcur - 1
rev_groups.append(ret[rcur])
elif rgl[0] > 0:
rcur = rcur - rgl[0]
rev_groups.append(ret[rcur : rcur + rgl[0]])
else:
raise ValueError("Two -1 lists in output")
groups.append(ret[cur:rcur])
rev_groups.reverse()
groups.extend(rev_groups)
break
return tuple(groups)
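
A short sketch of the grouping semantics (t0..t5 stand for tensors in the engine's flat output list):

ret = [t0, t1, t2, t3, t4, t5]
parse_groups(ret, [[], [2], [-1]])
# -> (t0, [t1, t2], [t3, t4, t5])
# [] yields a single tensor, [2] a list of two, and the one allowed [-1]
# group absorbs whatever remains between the fixed groups.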


class TrtCompiler:
"""
This class implements:
@@ -233,13 +304,15 @@ def __init__(
method="onnx",
input_names=None,
output_names=None,
output_lists=None,
export_args=None,
build_args=None,
input_profiles=None,
dynamic_batchsize=None,
use_cuda_graph=False,
timestamp=None,
fallback=False,
forward_override=None,
logger=None,
):
"""
@@ -255,6 +328,8 @@
'torch_trt' may not work for some nets. Also AMP must be turned off for it to work.
input_names: Optional list of input names. If None, will be read from the function signature.
output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.
output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list
of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details.
build_args: Optional args to pass to TRT builder. See polygraphy.Config for details.
input_profiles: Optional list of profiles for TRT builder and ONNX export.
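
As a hedged reference, one plausible way these arguments could be filled in (the values are illustrative assumptions, not taken from this PR):

args = {
    "precision": "fp16",             # build the engine with fp16 enabled
    "method": "onnx",                # export via ONNX before the TRT build
    "dynamic_batchsize": [1, 2, 8],  # [min, opt, max] for the batch axis
}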
@@ -279,6 +354,7 @@
self.method = method
self.return_dict = output_names is not None
self.output_names = output_names or []
self.output_lists = output_lists or []
self.profiles = input_profiles or []
self.dynamic_batchsize = dynamic_batchsize
self.export_args = export_args or {}
@@ -289,11 +365,19 @@
self.disabled = False

self.logger = logger or get_logger("monai.networks.trt_compiler")
self.argspec = inspect.getfullargspec(model.forward)

# Normally we read input_names from forward() but can be overridden
if input_names is None:
argspec = inspect.getfullargspec(model.forward)
input_names = argspec.args[1:]
input_names = self.argspec.args[1:]
self.defaults = {}
if self.argspec.defaults is not None:
for i in range(len(self.argspec.defaults)):
d = self.argspec.defaults[-i - 1]
if d is not None:
d = torch.tensor(d).cuda()
self.defaults[self.argspec.args[-i - 1]] = d

self.input_names = input_names
self.old_forward = model.forward

@@ -314,9 +398,9 @@ def _load_engine(self):
"""
try:
self.engine = TRTEngine(self.plan_path, self.logger)
self.input_names = self.engine.input_names
self.logger.info(f"Engine loaded, inputs:{self.engine.input_names}")
except Exception as e:
self.logger.debug(f"Exception while loading the engine:\n{e}")
self.logger.info(f"Exception while loading the engine:\n{e}")

def forward(self, model, argv, kwargs):
"""
@@ -329,18 +413,22 @@
Returns: Passing through wrapped module's forward() return value(s)

"""
args = self.defaults
args.update(kwargs)
if len(argv) > 0:
args.update(self._inputs_to_dict(argv))

if self.engine is None and not self.disabled:
# Restore original forward for export
new_forward = model.forward
model.forward = self.old_forward
try:
self._load_engine()
if self.engine is None:
build_args = kwargs.copy()
if len(argv) > 0:
build_args.update(self._inputs_to_dict(argv))
self._build_and_save(model, build_args)
# This will reassign input_names from the engine
build_args = args.copy()
with torch.no_grad():
self._build_and_save(model, build_args)
# This will reassign input_names from the engine
self._load_engine()
assert self.engine is not None
except Exception as e:
@@ -355,31 +443,30 @@
del param
# Call empty_cache to release GPU memory
torch.cuda.empty_cache()
# restore TRT hook
model.forward = new_forward
# Run the engine
try:
if len(argv) > 0:
kwargs.update(self._inputs_to_dict(argv))
argv = ()

if self.engine is not None:
# forward_trt is not thread safe as we do not use per-thread execution contexts
with lock_sm:
device = torch.cuda.current_device()
stream = torch.cuda.Stream(device=device)
self.engine.set_inputs(kwargs, stream.cuda_stream)
self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream)
self.engine.allocate_buffers(device=device)
# Need this to synchronize with Torch stream
stream.wait_stream(torch.cuda.current_stream())
ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph)
# if output_names is not None, return dictionary
if not self.return_dict:
ret = list(ret.values())
if len(ret) == 1:
if self.output_lists:
ret = parse_groups(ret, self.output_lists)
elif len(ret) == 1:
ret = ret[0]
return ret
except Exception as e:
if model is not None:
if self.fallback:
self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...")
else:
raise e
@@ -391,16 +478,11 @@ def _onnx_to_trt(self, onnx_path):
"""

profiles = []
if self.profiles:
for input_profile in self.profiles:
if isinstance(input_profile, Profile):
profiles.append(input_profile)
else:
p = Profile()
for name, dims in input_profile.items():
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
profiles.append(p)
for profile in self.profiles:
p = Profile()
for id, val in profile.items():
p.add(id, min=val[0], opt=val[1], max=val[2])
profiles.append(p)

build_args = self.build_args.copy()
build_args["tf32"] = self.precision != "fp32"
@@ -425,7 +507,7 @@ def _build_and_save(self, model, input_example):
return

export_args = self.export_args

engine_bytes = None
add_casts_around_norms(model)

if self.method == "torch_trt":
@@ -447,7 +529,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize):
engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
ir_model,
"forward",
inputs=tt_inputs,
arg_inputs=tt_inputs,
ir="torchscript",
enabled_precisions=enabled_precisions,
**export_args,
@@ -459,33 +541,47 @@ def get_torch_trt_input(input_shape, dynamic_batchsize):
raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!")
if len(dbs) != 3:
raise ValueError("dynamic_batchsize has to have len ==3 ")
profiles = {}
profile = {}
for id, val in input_example.items():
sh = val.shape[1:]
profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
self.profiles = [profiles]

if len(self.profiles) > 0:
export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)})
def add_profile(id, val):
sh = val.shape
if len(sh) > 0:
sh = sh[1:]
profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]

if isinstance(val, list | tuple):
for i in range(len(val)):
add_profile(f"{id}_{i}", val[i])
elif isinstance(val, torch.Tensor):
add_profile(id, val)
self.profiles = [profile]

self.dynamic_axes = get_dynamic_axes(self.profiles)

if len(self.dynamic_axes) > 0:
export_args.update({"dynamic_axes": self.dynamic_axes})

# Use temporary directory for easy cleanup in case of external weights
with tempfile.TemporaryDirectory() as tmpdir:
onnx_path = Path(tmpdir) / "model.onnx"
unrolled_input = unroll_input(self.input_names, input_example)
onnx_path = str(Path(tmpdir) / "model.onnx")
self.logger.info(
f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}"
f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\n"
+ f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}"
)
convert_to_onnx(
model,
input_example,
filename=str(onnx_path),
input_names=self.input_names,
filename=onnx_path,
input_names=list(unrolled_input.keys()),
output_names=self.output_names,
**export_args,
)
self.logger.info("Export to ONNX successful.")
engine_bytes = self._onnx_to_trt(str(onnx_path))

open(self.plan_path, "wb").write(engine_bytes)
engine_bytes = self._onnx_to_trt(onnx_path)
if engine_bytes:
open(self.plan_path, "wb").write(engine_bytes)


def trt_forward(self, *argv, **kwargs):
@@ -540,9 +636,11 @@ def trt_compile(
args["timestamp"] = timestamp

def wrap(model, path):
wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args)
model._trt_compiler = wrapper
model.forward = MethodType(trt_forward, model)
if not hasattr(model, "_trt_compiler"):
model.orig_forward = model.forward
wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args)
model._trt_compiler = wrapper
model.forward = MethodType(trt_forward, model)

def find_sub(parent, submodule):
idx = submodule.find(".")
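
For context, a hedged end-to-end usage sketch of the renamed trt_compile() API; the model choice, plan path, and shapes are assumptions for illustration only:

import torch
from monai.networks.nets import UNet
from monai.networks.trt_compiler import trt_compile

model = UNet(spatial_dims=3, in_channels=1, out_channels=2,
             channels=(16, 32, 64), strides=(2, 2)).cuda().eval()
# Patches model.forward; the engine is cached at "/models/unet_v1.plan".
model = trt_compile(model, "/models/unet_v1", args={"precision": "fp16"})

x = torch.randn(1, 1, 96, 96, 96).cuda()
with torch.no_grad():
    y = model(x)  # first call exports ONNX and builds the engine;
                  # later calls run TensorRT directly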