Skip to content

Commit a1d243f

Browse files
committed
Fixed mypy
Signed-off-by: Boris Fomitchev <[email protected]>
1 parent c7a1fb7 commit a1d243f

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

monai/networks/trt_compiler.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(self, plan_path, logger=None):
125125
self.output_names = []
126126
self.dtypes = []
127127
self.cur_profile = 0
128+
self.input_table = {}
128129
dtype_dict = trt_to_torch_dtype_dict()
129130
for idx in range(self.engine.num_io_tensors):
130131
binding = self.engine[idx]
@@ -241,7 +242,7 @@ def unroll_input(input_names, input_example):
241242

242243

243244
def parse_groups(
244-
ret: List[torch.Tensor], output_lists: List[int]
245+
ret: List[torch.Tensor], output_lists: List[List[int]]
245246
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]:
246247
"""
247248
Implements parsing of 'output_lists' arg of trt_compile().
@@ -260,36 +261,34 @@ def parse_groups(
260261
Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists
261262
262263
"""
263-
groups = []
264+
groups:Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple()
264265
cur = 0
265266
for l in range(len(output_lists)):
266267
gl = output_lists[l]
267268
assert len(gl) == 0 or len(gl) == 1
268269
if len(gl) == 0 or gl[0] == 0:
269-
groups.append(ret[cur])
270+
groups = (*groups, ret[cur])
270271
cur = cur + 1
271272
elif gl[0] > 0:
272-
groups.append(ret[cur : cur + gl[0]])
273+
groups = (*groups, ret[cur : cur + gl[0]])
273274
cur = cur + gl[0]
274275
elif gl[0] == -1:
275-
rev_groups = []
276+
rev_groups:Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple()
276277
rcur = len(ret)
277278
for rl in range(len(output_lists) - 1, l, -1):
278279
rgl = output_lists[rl]
279280
assert len(rgl) == 0 or len(rgl) == 1
280281
if len(rgl) == 0 or rgl[0] == 0:
281282
rcur = rcur - 1
282-
rev_groups.append(ret[rcur])
283+
rev_groups=(*rev_groups, ret[rcur])
283284
elif rgl[0] > 0:
284285
rcur = rcur - rgl[0]
285-
rev_groups.append(ret[rcur : rcur + rgl[0]])
286+
rev_groups=(*rev_groups, ret[rcur : rcur + rgl[0]])
286287
else:
287288
raise ValueError("Two -1 lists in output")
288-
groups.append(ret[cur:rcur])
289-
rev_groups.reverse()
290-
groups.extend(rev_groups)
289+
groups = (*groups, ret[cur:rcur], *rev_groups[::-1])
291290
break
292-
return tuple(groups)
291+
return groups
293292

294293

295294
class TrtCompiler:

0 commit comments

Comments
 (0)