Skip to content

Commit 5ffc03d

Browse files
authored
Merge branch 'dev' into pythonicworkflow
2 parents c026441 + 13b96ae commit 5ffc03d

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

monai/networks/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,9 +712,10 @@ def convert_to_onnx(
712712
onnx_model = onnx.load(f)
713713

714714
if do_constant_folding and polygraphy_imported:
715-
from polygraphy.backend.onnx.loader import fold_constants
715+
from polygraphy.backend.onnx.loader import fold_constants, save_onnx
716716

717-
fold_constants(onnx_model, size_threshold=constant_size_threshold)
717+
onnx_model = fold_constants(onnx_model, size_threshold=constant_size_threshold)
718+
save_onnx(onnx_model, f)
718719

719720
if verify:
720721
if isinstance(inputs, dict):

tests/test_trt_compile.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def tearDown(self):
6161
if current_device != self.gpu_device:
6262
torch.cuda.set_device(self.gpu_device)
6363

64-
@unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
64+
# @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
6565
def test_handler(self):
6666
from ignite.engine import Engine
6767

@@ -74,7 +74,7 @@ def test_handler(self):
7474

7575
with tempfile.TemporaryDirectory() as tempdir:
7676
engine = Engine(lambda e, b: None)
77-
args = {"method": "torch_trt"}
77+
args = {"method": "onnx", "dynamic_batchsize": [1, 4, 8]}
7878
TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine)
7979
engine.run([0] * 8, max_epochs=1)
8080
self.assertIsNotNone(net1._trt_compiler)
@@ -86,7 +86,11 @@ def test_lists(self):
8686
model = ListAdd().cuda()
8787

8888
with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
89-
args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}}
89+
args = {
90+
"output_lists": [[-1], [2], []],
91+
"export_args": {"dynamo": False, "verbose": True},
92+
"dynamic_batchsize": [1, 4, 8],
93+
}
9094
x = torch.randn(1, 16).to("cuda")
9195
y = torch.randn(1, 16).to("cuda")
9296
z = torch.randn(1, 16).to("cuda")

0 commit comments

Comments
 (0)