Skip to content

Commit 9c4060b

Browse files
SigureMoHz188
authored andcommitted
add removed typing
1 parent 82919cf commit 9c4060b

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

python/paddle/distributed/communication/all_gather.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, TypeVar
1818

1919
import numpy as np
2020

@@ -32,6 +32,8 @@
3232
from paddle.base.core import task
3333
from paddle.distributed.communication.group import Group
3434

35+
_T = TypeVar("_T")
36+
3537

3638
def all_gather(
3739
tensor_list: list[Tensor],
@@ -82,7 +84,9 @@ def all_gather(
8284
return stream.all_gather(tensor_list, tensor, group, sync_op)
8385

8486

85-
def all_gather_object(object_list: list, obj, group: Group = None) -> None:
87+
def all_gather_object(
88+
object_list: list[_T] | list[None], obj: _T, group: Group = None
89+
) -> None:
8690
"""
8791
8892
Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.
@@ -106,7 +110,7 @@ def all_gather_object(object_list: list, obj, group: Group = None) -> None:
106110
>>> import paddle.distributed as dist
107111
108112
>>> dist.init_parallel_env()
109-
>>> object_list = [None for _ in dist.get_world_size()] # type: ignore
113+
>>> object_list = [None for _ in range(dist.get_world_size())]
110114
>>> if dist.get_rank() == 0:
111115
... obj = {"foo": [1, 2, 3]}
112116
>>> else:

0 commit comments

Comments
 (0)