@@ -1112,6 +1112,38 @@ def _check_tensor_list(param, param_name) -> None:
11121112 )
11131113
11141114
1115+ def _group_or_default_group (group : Optional [ProcessGroup ] = None ) -> ProcessGroup :
1116+ if group is None or group is GroupMember .WORLD :
1117+ group = _get_default_group ()
1118+ return group
1119+
1120+
1121+ def _canonicalize_group_rank (
1122+ group : ProcessGroup ,
1123+ global_rank : Optional [int ] = None ,
1124+ group_rank : Optional [int ] = None ,
1125+ ) -> int :
1126+ """
1127+ Helper method to take _either_ a global rank or a group rank and produce a group rank.
1128+ """
1129+ if group_rank is not None :
1130+ if global_rank is not None :
1131+ raise ValueError ("Can't specify both group_rank and global_rank" )
1132+ else :
1133+ if global_rank is None :
1134+ raise ValueError ("Must specify global_rank or group_rank" )
1135+ group_rank = get_group_rank (group , global_rank )
1136+ return group_rank
1137+
1138+
1139+ def _check_not_self_rank (group : ProcessGroup , rank : int , rank_type : str ):
1140+ if group .rank () == rank :
1141+ raise ValueError (
1142+ f"Invalid { rank_type } rank: { rank_type } rank should not be the same as "
1143+ "the rank of the current process."
1144+ )
1145+
1146+
11151147def _as_iterable (obj ) -> collections .abc .Iterable :
11161148 return obj if isinstance (obj , list ) else (obj ,)
11171149
@@ -2217,7 +2249,11 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int:
22172249
22182250
22192251def isend (
2220- tensor : torch .Tensor , dst : int , group : Optional [ProcessGroup ] = None , tag : int = 0
2252+ tensor : torch .Tensor ,
2253+ dst : Optional [int ] = None ,
2254+ group : Optional [ProcessGroup ] = None ,
2255+ tag : int = 0 ,
2256+ group_dst : Optional [int ] = None ,
22212257) -> Optional [Work ]:
22222258 """
22232259 Send a tensor asynchronously.
@@ -2229,18 +2265,23 @@ def isend(
22292265 .. warning::
22302266 ``tag`` is not supported with the NCCL backend.
22312267
2268+ Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.
2269+
22322270 Args:
22332271 tensor (Tensor): Tensor to send.
22342272 dst (int): Destination rank on global process group (regardless of ``group`` argument)
22352273 group (ProcessGroup, optional): The process group to work on. If None,
22362274 the default process group will be used.
22372275 tag (int, optional): Tag to match send with remote recv
2276+ group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``
22382277
22392278 Returns:
22402279 A distributed request object.
22412280 None, if not part of the group
22422281
22432282 """
2283+ group = _group_or_default_group (group )
2284+ group_dst = _canonicalize_group_rank (group , dst , group_dst )
22442285 _check_single_tensor (tensor , "tensor" )
22452286 if _rank_not_in_group (group ):
22462287 _warn_not_in_group ("isend" )
@@ -2249,34 +2290,32 @@ def isend(
22492290 if tensor .is_complex ():
22502291 tensor = torch .view_as_real (tensor )
22512292
2252- if group is None or group is GroupMember .WORLD :
2253- pg = _get_default_group ()
2254- else :
2255- pg = group
2256- dst = get_group_rank (pg , dst )
2257-
2258- return pg .send ([tensor ], dst , tag )
2293+ return group .send ([tensor ], group_dst , tag )
22592294
22602295
22612296def irecv (
22622297 tensor : torch .Tensor ,
22632298 src : Optional [int ] = None ,
22642299 group : Optional [ProcessGroup ] = None ,
22652300 tag : int = 0 ,
2301+ group_src : Optional [int ] = None ,
22662302) -> Optional [Work ]:
22672303 """
22682304 Receives a tensor asynchronously.
22692305
22702306 .. warning::
22712307 ``tag`` is not supported with the NCCL backend.
22722308
2309+ Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self.
2310+
22732311 Args:
22742312 tensor (Tensor): Tensor to fill with received data.
22752313 src (int, optional): Source rank on global process group (regardless of ``group`` argument).
22762314 Will receive from any process if unspecified.
22772315 group (ProcessGroup, optional): The process group to work on. If None,
22782316 the default process group will be used.
22792317 tag (int, optional): Tag to match recv with remote send
2318+ group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
22802319
22812320 Returns:
22822321 A distributed request object.
@@ -2291,24 +2330,21 @@ def irecv(
22912330 if tensor .is_complex ():
22922331 tensor = torch .view_as_real (tensor )
22932332
2294- if group is None or group is GroupMember .WORLD :
2295- pg = _get_default_group ()
2296- else :
2297- pg = group
2298-
2299- if src is None :
2300- return pg .recv_anysource ([tensor ], tag )
2333+ group = _group_or_default_group (group )
2334+ if src is None and group_src is None :
2335+ return group .recv_anysource ([tensor ], tag )
23012336 else :
2302- if pg is GroupMember .WORLD :
2303- return pg .recv ([tensor ], src , tag )
2304- else :
2305- group_src_rank = get_group_rank (pg , src )
2306- return pg .recv ([tensor ], group_src_rank , tag )
2337+ group_src = _canonicalize_group_rank (group , src , group_src )
2338+ return group .recv ([tensor ], group_src , tag )
23072339
23082340
23092341@_exception_logger
23102342def send (
2311- tensor : torch .Tensor , dst : int , group : Optional [ProcessGroup ] = None , tag : int = 0
2343+ tensor : torch .Tensor ,
2344+ dst : Optional [int ] = None ,
2345+ group : Optional [ProcessGroup ] = None ,
2346+ tag : int = 0 ,
2347+ group_dst : Optional [int ] = None ,
23122348) -> None :
23132349 """
23142350 Send a tensor synchronously.
@@ -2323,14 +2359,12 @@ def send(
23232359 group (ProcessGroup, optional): The process group to work on. If None,
23242360 the default process group will be used.
23252361 tag (int, optional): Tag to match send with remote recv
2362+ group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``.
23262363
23272364 """
2328- if get_rank () == dst :
2329- raise ValueError (
2330- "Invalid destination rank: destination rank should not be the same as "
2331- "the rank of the current process."
2332- )
2333-
2365+ group = _group_or_default_group (group )
2366+ group_dst = _canonicalize_group_rank (group , dst , group_dst )
2367+ _check_not_self_rank (group , group_dst , "destination" )
23342368 _check_single_tensor (tensor , "tensor" )
23352369 if _rank_not_in_group (group ):
23362370 _warn_not_in_group ("send" )
@@ -2339,12 +2373,7 @@ def send(
23392373 if tensor .is_complex ():
23402374 tensor = torch .view_as_real (tensor )
23412375
2342- if group is None or group is GroupMember .WORLD :
2343- default_pg = _get_default_group ()
2344- default_pg .send ([tensor ], dst , tag ).wait ()
2345- else :
2346- group_dst_rank = get_group_rank (group , dst )
2347- group .send ([tensor ], group_dst_rank , tag ).wait ()
2376+ group .send ([tensor ], group_dst , tag ).wait ()
23482377
23492378
23502379@_exception_logger
@@ -2353,6 +2382,7 @@ def recv(
23532382 src : Optional [int ] = None ,
23542383 group : Optional [ProcessGroup ] = None ,
23552384 tag : int = 0 ,
2385+ group_src : Optional [int ] = None ,
23562386) -> int :
23572387 """
23582388 Receives a tensor synchronously.
@@ -2367,7 +2397,7 @@ def recv(
23672397 group (ProcessGroup, optional): The process group to work on. If None,
23682398 the default process group will be used.
23692399 tag (int, optional): Tag to match recv with remote send
2370-
2400+ group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
23712401 Returns:
23722402 Sender rank
23732403 -1, if not part of the group
@@ -2381,23 +2411,18 @@ def recv(
23812411 if tensor .is_complex ():
23822412 tensor = torch .view_as_real (tensor )
23832413
2384- pg = group or _get_default_group ( )
2414+ group = _group_or_default_group ( group )
23852415
2386- if src is None :
2387- work = pg .recv_anysource ([tensor ], tag )
2416+ if src is None and group_src is None :
2417+ work = group .recv_anysource ([tensor ], tag )
23882418 work .wait ()
23892419 src_rank = work ._source_rank ()
2390- if group is None or group is GroupMember .WORLD :
2391- return src_rank
2392- else :
2393- return get_global_rank (pg , src_rank )
2420+ return get_global_rank (group , src_rank )
23942421 else :
2395- if group is None or group is GroupMember .WORLD :
2396- pg .recv ([tensor ], src , tag ).wait ()
2397- else :
2398- group_src_rank = get_group_rank (pg , src )
2399- pg .recv ([tensor ], group_src_rank , tag ).wait ()
2400- return src
2422+ group_src = _canonicalize_group_rank (group , src , group_src )
2423+ _check_not_self_rank (group , group_src , "source" )
2424+ group .recv ([tensor ], group_src , tag ).wait ()
2425+ return get_global_rank (group , group_src )
24012426
24022427
24032428class _IllegalWork (Work ):
0 commit comments