@@ -125,6 +125,7 @@ def __init__(self, plan_path, logger=None):
125125 self .output_names = []
126126 self .dtypes = []
127127 self .cur_profile = 0
128+ self .input_table = {}
128129 dtype_dict = trt_to_torch_dtype_dict ()
129130 for idx in range (self .engine .num_io_tensors ):
130131 binding = self .engine [idx ]
@@ -241,7 +242,7 @@ def unroll_input(input_names, input_example):
241242
242243
243244def parse_groups (
244- ret : List [torch .Tensor ], output_lists : List [int ]
245+ ret : List [torch .Tensor ], output_lists : List [List [ int ] ]
245246) -> Tuple [Union [torch .Tensor , List [torch .Tensor ]], ...]:
246247 """
247248 Implements parsing of 'output_lists' arg of trt_compile().
@@ -260,36 +261,34 @@ def parse_groups(
260261 Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists
261262
262263 """
263- groups = []
264+ groups : Tuple [ Union [ torch . Tensor , List [ torch . Tensor ]], ...] = tuple ()
264265 cur = 0
265266 for l in range (len (output_lists )):
266267 gl = output_lists [l ]
267268 assert len (gl ) == 0 or len (gl ) == 1
268269 if len (gl ) == 0 or gl [0 ] == 0 :
269- groups . append ( ret [cur ])
270+ groups = ( * groups , ret [cur ])
270271 cur = cur + 1
271272 elif gl [0 ] > 0 :
272- groups . append ( ret [cur : cur + gl [0 ]])
273+ groups = ( * groups , ret [cur : cur + gl [0 ]])
273274 cur = cur + gl [0 ]
274275 elif gl [0 ] == - 1 :
275- rev_groups = []
276+ rev_groups : Tuple [ Union [ torch . Tensor , List [ torch . Tensor ]], ...] = tuple ()
276277 rcur = len (ret )
277278 for rl in range (len (output_lists ) - 1 , l , - 1 ):
278279 rgl = output_lists [rl ]
279280 assert len (rgl ) == 0 or len (rgl ) == 1
280281 if len (rgl ) == 0 or rgl [0 ] == 0 :
281282 rcur = rcur - 1
282- rev_groups . append ( ret [rcur ])
283+ rev_groups = ( * rev_groups , ret [rcur ])
283284 elif rgl [0 ] > 0 :
284285 rcur = rcur - rgl [0 ]
285- rev_groups . append ( ret [rcur : rcur + rgl [0 ]])
286+ rev_groups = ( * rev_groups , ret [rcur : rcur + rgl [0 ]])
286287 else :
287288 raise ValueError ("Two -1 lists in output" )
288- groups .append (ret [cur :rcur ])
289- rev_groups .reverse ()
290- groups .extend (rev_groups )
289+ groups = (* groups , ret [cur :rcur ], * rev_groups [::- 1 ])
291290 break
292- return tuple ( groups )
291+ return groups
293292
294293
295294class TrtCompiler :
0 commit comments