 
 from monai.handlers import TrtHandler
 from monai.networks import trt_compile
-from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
+from monai.networks.nets import cell_sam_wrapper, vista3d132
 from monai.utils import min_version, optional_import
-from tests.utils import (
-    SkipIfAtLeastPyTorchVersion,
-    SkipIfBeforeComputeCapabilityVersion,
-    skip_if_no_cuda,
-    skip_if_quick,
-    skip_if_windows,
-)
+from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
 
 trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
+torch_tensorrt, torch_trt_imported = optional_import("torch_tensorrt")
 polygraphy, polygraphy_imported = optional_import("polygraphy")
 build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")
 
 TEST_CASE_1 = ["fp32"]
 TEST_CASE_2 = ["fp16"]
 
 
+class ListAdd(torch.nn.Module):
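+    """Toy module with a list-of-tensors input and nested outputs (a list copy,
+    a two-tensor list, and a plain tensor), used to exercise trt_compile() list handling."""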
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1):
+        y1 = y.clone()
+        x1 = x.copy()
+        z1 = z + y
+        for xi in x:
+            y1 = y1 + xi + bs
+        return x1, [y1, z1], y1 + z1
+
+
 @skip_if_windows
 @skip_if_no_cuda
 @skip_if_quick
@@ -53,7 +61,7 @@ def tearDown(self):
         if current_device != self.gpu_device:
             torch.cuda.set_device(self.gpu_device)
 
-    @SkipIfAtLeastPyTorchVersion((2, 4, 1))
+    # @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
     def test_handler(self):
         from ignite.engine import Engine
 
@@ -66,37 +74,31 @@ def test_handler(self):
 
         with tempfile.TemporaryDirectory() as tempdir:
             engine = Engine(lambda e, b: None)
-            args = {"method": "torch_trt"}
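+            # "onnx" export path; as the assertions below show, the TRT engine is still built lazily on first forward()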
+            args = {"method": "onnx", "dynamic_batchsize": [1, 4, 8]}
             TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine)
             engine.run([0] * 8, max_epochs=1)
             self.assertIsNotNone(net1._trt_compiler)
             self.assertIsNone(net1._trt_compiler.engine)
             net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda"))
             self.assertIsNotNone(net1._trt_compiler.engine)
 
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
-    def test_unet_value(self, precision):
-        model = UNet(
-            spatial_dims=3,
-            in_channels=1,
-            out_channels=2,
-            channels=(2, 2, 4, 8, 4),
-            strides=(2, 2, 2, 2),
-            num_res_units=2,
-            norm="batch",
-        ).cuda()
+    def test_lists(self):
+        model = ListAdd().cuda()
+
         with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
-            model.eval()
-            input_example = torch.randn(2, 1, 96, 96, 96).cuda()
-            output_example = model(input_example)
-            args: dict = {"builder_optimization_level": 1}
-            trt_compile(
-                model,
-                f"{tmpdir}/test_unet_trt_compile",
-                args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]},
-            )
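+            # One output_lists spec per model output: x1 -> variable-length list ([-1]),
+            # [y1, z1] -> list of two tensors ([2]), y1 + z1 -> plain tensor ([])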
+            args = {
+                "output_lists": [[-1], [2], []],
+                "export_args": {"dynamo": False, "verbose": True},
+                "dynamic_batchsize": [1, 4, 8],
+            }
+            x = torch.randn(1, 16).to("cuda")
+            y = torch.randn(1, 16).to("cuda")
+            z = torch.randn(1, 16).to("cuda")
+            input_example = ([x, y, z], y.clone(), z.clone())
+            output_example = model(*input_example)
+            trt_compile(model, f"{tmpdir}/test_lists", args=args)
             self.assertIsNone(model._trt_compiler.engine)
-            trt_output = model(input_example)
+            trt_output = model(*input_example)
             # Check that lazy TRT build succeeded
             self.assertIsNotNone(model._trt_compiler.engine)
             torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
@@ -109,11 +111,7 @@ def test_cell_sam_wrapper_value(self, precision):
             model.eval()
             input_example = torch.randn(1, 3, 128, 128).to("cuda")
             output_example = model(input_example)
-            trt_compile(
-                model,
-                f"{tmpdir}/test_cell_sam_wrapper_trt_compile",
-                args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
-            )
+            trt_compile(model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", args={"precision": precision})
             self.assertIsNone(model._trt_compiler.engine)
             trt_output = model(input_example)
             # Check that lazy TRT build succeeded
@@ -130,7 +128,7 @@ def test_vista3d(self, precision):
             model = trt_compile(
                 model,
                 f"{tmpdir}/test_vista3d_trt_compile",
-                args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
+                args={"precision": precision, "dynamic_batchsize": [1, 2, 4]},
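+                # dynamic_batchsize = [min, opt, max] batch for the TRT optimization profile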
                 submodule=["image_encoder.encoder", "class_head"],
             )
         self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)