@@ -469,57 +469,61 @@ class P2POp:
             The type of ``op`` is either ``torch.distributed.isend`` or
             ``torch.distributed.irecv``.
         tensor (Tensor): Tensor to send or receive.
-        peer (int): Destination or source rank.
+        peer (int, optional): Destination or source rank.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match send with recv.
+        group_peer (int, optional): Destination or source rank.
     """
 
     def __init__(
         self,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Init."""
         self.op = op
         self.tensor = tensor
-        self.peer = peer
-        self.group = group
+        self.group = _group_or_default_group(group)
+        self.peer = _canonicalize_group_rank(
+            self.group, peer, group_peer, return_global=True
+        )
         self.tag = tag
+        self.group_peer = _canonicalize_group_rank(self.group, peer, group_peer)
 
     def __new__(
         cls,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Create and return a new instance of the class."""
         _check_op(op)
         _check_single_tensor(tensor, "tensor")
+
         return object.__new__(cls)
 
     def __repr__(self):
         my_group_rank = get_rank(self.group)
-        peer_group_rank = (
-            get_group_rank(self.group, self.peer) if self.group else self.peer
-        )
         op_name = self.op.__name__
         group_name = self.group.group_name if self.group else "default_pg"
         if "send" in op_name:
             s = my_group_rank
-            d = peer_group_rank
+            d = self.group_peer
         elif "recv" in op_name:
-            s = peer_group_rank
+            s = self.group_peer
             d = my_group_rank
         else:
             return super().__repr__()
 
-        return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})"
+        return f"P2POp({op_name} pg={group_name}, group_src={s}, group_dst={d}, {self.tensor.shape}, {self.tensor.dtype})"
 
 
 class _CollOp:
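For context, a minimal usage sketch of the updated `P2POp` constructor follows; the subgroup ranks, tensor shape, and surrounding process-group initialization are assumptions for illustration, not part of this change:

```python
import torch
import torch.distributed as dist

# Assumes torch.distributed is already initialized (e.g. via init_process_group)
# and that global ranks 0 and 2 belong to the subgroup created below.
subgroup = dist.new_group(ranks=[0, 2])

buf = torch.zeros(4)

# The peer can still be addressed by its global rank ...
send_op = dist.P2POp(dist.isend, buf, peer=2, group=subgroup)

# ... or, with the new keyword, by its rank within `group`.
recv_op = dist.P2POp(dist.irecv, buf, group=subgroup, group_peer=1)

print(send_op)  # __repr__ now reports group-local group_src/group_dst ranks
```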
@@ -2546,7 +2550,7 @@ def _coalescing_manager(
         work.wait()  # type: ignore[possibly-undefined]
 
 
-def batch_isend_irecv(p2p_op_list):
+def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
     """
     Send or Receive a batch of tensors asynchronously and return a list of requests.
 
@@ -2589,17 +2593,33 @@ def batch_isend_irecv(p2p_op_list):
     _check_p2p_op_list(p2p_op_list)
     group = p2p_op_list[0].group
     device = p2p_op_list[0].tensor.device
+
+    def peer_kwarg(op: P2POp) -> Dict[str, int]:
+        key = "group_dst" if op.op == isend else "group_src"
+        return {key: op.group_peer}
+
     if device.type == "cuda":
         # NCCL style coalescing
         with _coalescing_manager(group, device, async_ops=True) as cm:
             for p2p_op in p2p_op_list:
-                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+                p2p_op.op(
+                    p2p_op.tensor,
+                    group=p2p_op.group,
+                    tag=p2p_op.tag,
+                    **peer_kwarg(p2p_op),
+                )
+
         return cm.works
     else:
         # Backward support for Gloo
         reqs = []
         for p2p_op in p2p_op_list:
-            work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+            work = p2p_op.op(
+                p2p_op.tensor,
+                group=p2p_op.group,
+                tag=p2p_op.tag,
+                **peer_kwarg(p2p_op),
+            )
             if work:
                 reqs.append(work)
         return reqs
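A usage sketch for the batched path, which exercises the keyword dispatch through `batch_isend_irecv`; the ring-exchange pattern, tensor contents, and CUDA device selection are assumptions for illustration:

```python
import torch
import torch.distributed as dist

# Assumes the default process group is initialized with an NCCL backend
# and that each rank has exclusive use of one CUDA device.
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())

send_buf = torch.full((4,), float(rank), device=device)
recv_buf = torch.empty(4, device=device)

# Each rank sends to its right neighbor and receives from its left neighbor.
ops = [
    dist.P2POp(dist.isend, send_buf, peer=(rank + 1) % world_size),
    dist.P2POp(dist.irecv, recv_buf, peer=(rank - 1) % world_size),
]

reqs = dist.batch_isend_irecv(ops)
for req in reqs:
    req.wait()
```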