@@ -469,57 +469,61 @@ class P2POp:
             The type of ``op`` is either ``torch.distributed.isend`` or
             ``torch.distributed.irecv``.
         tensor (Tensor): Tensor to send or receive.
-        peer (int): Destination or source rank.
+        peer (int, optional): Destination or source rank.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match send with recv.
+        group_peer (int, optional): Destination or source rank.
     """

     def __init__(
         self,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Init."""
         self.op = op
         self.tensor = tensor
-        self.peer = peer
-        self.group = group
+        self.group = _group_or_default_group(group)
+        self.peer = _canonicalize_group_rank(
+            self.group, peer, group_peer, return_global=True
+        )
         self.tag = tag
+        self.group_peer = _canonicalize_group_rank(self.group, peer, group_peer)

     def __new__(
         cls,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Create and return a new instance of the class."""
         _check_op(op)
         _check_single_tensor(tensor, "tensor")
+
         return object.__new__(cls)

     def __repr__(self):
         my_group_rank = get_rank(self.group)
-        peer_group_rank = (
-            get_group_rank(self.group, self.peer) if self.group else self.peer
-        )
         op_name = self.op.__name__
         group_name = self.group.group_name if self.group else "default_pg"
         if "send" in op_name:
             s = my_group_rank
-            d = peer_group_rank
+            d = self.group_peer
         elif "recv" in op_name:
-            s = peer_group_rank
+            s = self.group_peer
             d = my_group_rank
         else:
             return super().__repr__()

-        return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})"
+        return f"P2POp({op_name} pg={group_name}, group_src={s}, group_dst={d}, {self.tensor.shape}, {self.tensor.dtype})"


 class _CollOp:
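A minimal usage sketch of `P2POp` after this change, assuming a two-rank job with an initialized default process group (tensor shapes and variable names here are illustrative, not taken from the commit):

    import torch
    import torch.distributed as dist

    rank = dist.get_rank()
    peer = 1 - rank  # the other rank in a 2-process job

    send_buf = torch.ones(4) * rank
    recv_buf = torch.zeros(4)

    ops = [
        # `peer` is a global rank; with this change it becomes optional and
        # `group_peer` (a rank within `group`) may be passed instead.
        dist.P2POp(dist.isend, send_buf, peer=peer),
        dist.P2POp(dist.irecv, recv_buf, peer=peer),
    ]

    for req in dist.batch_isend_irecv(ops):
        req.wait()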
@@ -2545,7 +2549,7 @@ def _coalescing_manager(
         work.wait()  # type: ignore[possibly-undefined]


-def batch_isend_irecv(p2p_op_list):
+def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
     """
     Send or Receive a batch of tensors asynchronously and return a list of requests.

@@ -2588,17 +2592,33 @@ def batch_isend_irecv(p2p_op_list):
     _check_p2p_op_list(p2p_op_list)
     group = p2p_op_list[0].group
     device = p2p_op_list[0].tensor.device
+
+    def peer_kwarg(op: P2POp) -> Dict[str, int]:
+        key = "group_dst" if op.op == isend else "group_src"
+        return {key: op.group_peer}
+
     if device.type == "cuda":
         # NCCL style coalescing
         with _coalescing_manager(group, device, async_ops=True) as cm:
             for p2p_op in p2p_op_list:
-                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+                p2p_op.op(
+                    p2p_op.tensor,
+                    group=p2p_op.group,
+                    tag=p2p_op.tag,
+                    **peer_kwarg(p2p_op),
+                )
+
         return cm.works
     else:
         # Backward support for Gloo
         reqs = []
         for p2p_op in p2p_op_list:
-            work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+            work = p2p_op.op(
+                p2p_op.tensor,
+                group=p2p_op.group,
+                tag=p2p_op.tag,
+                **peer_kwarg(p2p_op),
+            )
             if work:
                 reqs.append(work)
         return reqs
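Since `batch_isend_irecv` now forwards the peer as a `group_dst`/`group_src` keyword via the `peer_kwarg` helper, a rough sketch of addressing a peer by its rank inside a sub-group could look like this (assuming a world of at least two ranks, all of which call `new_group`; `subgroup`, `buf`, and the rank arithmetic are illustrative):

    # Each rank addresses its counterpart by the peer's rank *within* subgroup.
    subgroup = dist.new_group(ranks=[0, 1])
    rank = dist.get_rank()
    if rank in (0, 1):
        buf = torch.arange(4, dtype=torch.float32)
        op = dist.isend if rank == 0 else dist.irecv
        # P2POp resolves group_peer to a global rank internally, and
        # batch_isend_irecv passes it back down as group_dst / group_src.
        p2p = dist.P2POp(op, buf, group=subgroup, group_peer=1 - rank)
        for req in dist.batch_isend_irecv([p2p]):
            req.wait()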