@@ -261,7 +261,7 @@ def parse_groups(
261261 Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists
262262
263263 """
264- groups :Tuple [Union [torch .Tensor , List [torch .Tensor ]], ...] = tuple ()
264+ groups : Tuple [Union [torch .Tensor , List [torch .Tensor ]], ...] = tuple ()
265265 cur = 0
266266 for l in range (len (output_lists )):
267267 gl = output_lists [l ]
@@ -273,17 +273,17 @@ def parse_groups(
273273 groups = (* groups , ret [cur : cur + gl [0 ]])
274274 cur = cur + gl [0 ]
275275 elif gl [0 ] == - 1 :
276- rev_groups :Tuple [Union [torch .Tensor , List [torch .Tensor ]], ...] = tuple ()
276+ rev_groups : Tuple [Union [torch .Tensor , List [torch .Tensor ]], ...] = tuple ()
277277 rcur = len (ret )
278278 for rl in range (len (output_lists ) - 1 , l , - 1 ):
279279 rgl = output_lists [rl ]
280280 assert len (rgl ) == 0 or len (rgl ) == 1
281281 if len (rgl ) == 0 or rgl [0 ] == 0 :
282282 rcur = rcur - 1
283- rev_groups = (* rev_groups , ret [rcur ])
283+ rev_groups = (* rev_groups , ret [rcur ])
284284 elif rgl [0 ] > 0 :
285285 rcur = rcur - rgl [0 ]
286- rev_groups = (* rev_groups , ret [rcur : rcur + rgl [0 ]])
286+ rev_groups = (* rev_groups , ret [rcur : rcur + rgl [0 ]])
287287 else :
288288 raise ValueError ("Two -1 lists in output" )
289289 groups = (* groups , ret [cur :rcur ], * rev_groups [::- 1 ])
0 commit comments