@@ -61,7 +61,7 @@ def tearDown(self):
         if current_device != self.gpu_device:
             torch.cuda.set_device(self.gpu_device)
 
-    @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
+    # @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
     def test_handler(self):
         from ignite.engine import Engine
 
@@ -74,7 +74,7 @@ def test_handler(self):
 
         with tempfile.TemporaryDirectory() as tempdir:
             engine = Engine(lambda e, b: None)
-            args = {"method": "torch_trt"}
+            args = {"method": "onnx", "dynamic_batchsize": [1, 4, 8]}
             TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine)
             engine.run([0] * 8, max_epochs=1)
             self.assertIsNotNone(net1._trt_compiler)
@@ -86,7 +86,11 @@ def test_lists(self):
         model = ListAdd().cuda()
 
         with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
-            args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}}
+            args = {
+                "output_lists": [[-1], [2], []],
+                "export_args": {"dynamo": False, "verbose": True},
+                "dynamic_batchsize": [1, 4, 8],
+            }
             x = torch.randn(1, 16).to("cuda")
             y = torch.randn(1, 16).to("cuda")
             z = torch.randn(1, 16).to("cuda")
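For context, a minimal sketch of how the same args dict drives MONAI's trt_compile outside the handler, assuming MONAI >= 1.4 with TensorRT and polygraphy available; the toy model, output path, and batch-size choice are illustrative assumptions, not part of this diff:

    import torch
    from monai.networks import trt_compile

    model = torch.nn.Linear(16, 4).cuda().eval()  # stand-in model, assumption
    # "dynamic_batchsize": [min, opt, max] requests a TensorRT optimization
    # profile that accepts batch sizes 1..8 and is tuned for batch size 4.
    args = {"method": "onnx", "dynamic_batchsize": [1, 4, 8]}
    model = trt_compile(model, "/tmp/trt_model", args=args)
    with torch.no_grad():
        # compilation is deferred; the first forward pass triggers the TRT build
        out = model(torch.randn(4, 16, device="cuda"))

This mirrors what the updated test exercises: TrtHandler passes args through to trt_compile, and the engine's first run populates net1._trt_compiler, which the assertion checks.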