Skip to content

Commit d07a57f

Browse files
committed
Added rename table for TRT engine, test for output lists
Signed-off-by: Boris Fomitchev <[email protected]>
1 parent 214def9 commit d07a57f

File tree

3 files changed

+52
-12
lines changed

3 files changed

+52
-12
lines changed

monai/networks/trt_compiler.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def set_inputs(self, feed_dict, stream):
167167

168168
def try_set_inputs():
169169
for binding in self.input_names:
170-
t = feed_dict[binding]
170+
t = feed_dict.get(self.input_table[binding], None)
171171
if t is not None:
172172
t = t.contiguous()
173173
shape = t.shape
@@ -222,6 +222,10 @@ def infer(self, stream, use_cuda_graph=False):
222222
return self.tensors
223223

224224

225+
def make_tensor(d):
226+
return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda()
227+
228+
225229
def unroll_input(input_names, input_example):
226230
# Simulate list/tuple unrolling during ONNX export
227231
unrolled_input = {}
@@ -230,9 +234,9 @@ def unroll_input(input_names, input_example):
230234
if val is not None:
231235
if isinstance(val, list | tuple):
232236
for i in range(len(val)):
233-
unrolled_input[f"{name}_{i}"] = val[i]
237+
unrolled_input[f"{name}_{i}"] = make_tensor(val[i])
234238
else:
235-
unrolled_input[name] = val
239+
unrolled_input[name] = make_tensor(val)
236240
return unrolled_input
237241

238242

@@ -375,8 +379,8 @@ def __init__(
375379
for i in range(len(self.argspec.defaults)):
376380
d = self.argspec.defaults[-i - 1]
377381
if d is not None:
378-
d = torch.tensor(d).cuda()
379-
self.defaults[self.argspec.args[-i - 1]] = d
382+
d = make_tensor(d)
383+
self.defaults[self.argspec.args[-i - 1]] = d
380384

381385
self.input_names = input_names
382386
self.old_forward = model.forward
@@ -398,7 +402,16 @@ def _load_engine(self):
398402
"""
399403
try:
400404
self.engine = TRTEngine(self.plan_path, self.logger)
401-
self.logger.info(f"Engine loaded, inputs:{self.engine.input_names}")
405+
# Make sure we have names correct
406+
input_table = {}
407+
for name in self.engine.input_names:
408+
if name.startswith("__") and name not in self.input_names:
409+
orig_name = name[2:]
410+
else:
411+
orig_name = name
412+
input_table[name] = orig_name
413+
self.engine.input_table = input_table
414+
self.logger.info(f"Engine loaded, inputs:{self.engine.input_table}")
402415
except Exception as e:
403416
self.logger.info(f"Exception while loading the engine:\n{e}")
404417

monai/networks/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def convert_to_onnx(
703703
onnx_inputs,
704704
f=f,
705705
input_names=input_names,
706-
output_names=output_names,
706+
output_names=output_names or None,
707707
dynamic_axes=dynamic_axes,
708708
opset_version=opset_version,
709709
do_constant_folding=do_constant_folding,

tests/test_trt_compile.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import tempfile
1515
import unittest
16+
from typing import List
1617

1718
import torch
1819
from parameterized import parameterized
@@ -32,6 +33,19 @@
3233
TEST_CASE_2 = ["fp16"]
3334

3435

36+
class ListAdd(torch.nn.Module):
37+
def __init__(self):
38+
super().__init__()
39+
40+
def forward(self, x: List[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = float(0.1)):
41+
y1 = y.clone()
42+
x1 = x.copy()
43+
z1 = z + y
44+
for xi in x:
45+
y1 = y1 + xi + bs
46+
return x1, [y1, z1], y1 + z1
47+
48+
3549
@skip_if_windows
3650
@skip_if_no_cuda
3751
@skip_if_quick
@@ -68,6 +82,23 @@ def test_handler(self):
6882
net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda"))
6983
self.assertIsNotNone(net1._trt_compiler.engine)
7084

85+
def test_lists(self):
86+
model = ListAdd().cuda()
87+
88+
with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
89+
args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}}
90+
x = torch.randn(1, 16).to("cuda")
91+
y = torch.randn(1, 16).to("cuda")
92+
z = torch.randn(1, 16).to("cuda")
93+
input_example = ([x, y, z], y.clone(), z.clone())
94+
output_example = model(*input_example)
95+
trt_compile(model, f"{tmpdir}/test_lists", args=args)
96+
self.assertIsNone(model._trt_compiler.engine)
97+
trt_output = model(*input_example)
98+
# Check that lazy TRT build succeeded
99+
self.assertIsNotNone(model._trt_compiler.engine)
100+
torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
101+
71102
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
72103
@unittest.skipUnless(has_sam, "Requires SAM installation")
73104
def test_cell_sam_wrapper_value(self, precision):
@@ -76,11 +107,7 @@ def test_cell_sam_wrapper_value(self, precision):
76107
model.eval()
77108
input_example = torch.randn(1, 3, 128, 128).to("cuda")
78109
output_example = model(input_example)
79-
trt_compile(
80-
model,
81-
f"{tmpdir}/test_cell_sam_wrapper_trt_compile",
82-
args={"precision": precision},
83-
)
110+
trt_compile(model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", args={"precision": precision})
84111
self.assertIsNone(model._trt_compiler.engine)
85112
trt_output = model(input_example)
86113
# Check that lazy TRT build succeeded

0 commit comments

Comments
 (0)