Skip to content

Commit

Permalink
str
Browse files Browse the repository at this point in the history
Signed-off-by: Xiaowei Jiang <[email protected]>
  • Loading branch information
xwjiang2010 committed Jun 27, 2024
1 parent c0bba5e commit 4c324ad
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GraphCaptureContext:


def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
Expand Down Expand Up @@ -473,11 +473,11 @@ def recv_object(self, src: int) -> Any:

def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
Expand Down Expand Up @@ -558,9 +558,9 @@ def broadcast_tensor_dict(

def send_tensor_dict(
self,
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
Expand Down Expand Up @@ -599,7 +599,7 @@ def send_tensor_dict(
def recv_tensor_dict(
self,
src: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
Expand All @@ -615,7 +615,7 @@ def recv_tensor_dict(
assert src < self.world_size, f"Invalid src rank ({src})"

recv_metadata_list = self.recv_object(src=src)
tensor_dict : Dict[Any, Any] = {}
tensor_dict : Dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
Expand Down

0 comments on commit 4c324ad

Please sign in to comment.