@@ -3747,7 +3747,8 @@ def _init_two_pg2_subgroups(self, world_size: int = 4):

     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_gather_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3758,28 +3759,48 @@ def test_gather_subgroup(self):
         input = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
             gather_list = [torch.empty_like(input) for _ in range(subgroup.size())]
-            torch.distributed.gather(
-                input,
-                gather_list=gather_list,
-                dst=self.rank,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                # global_dst=0 group_dst=0 my_global_rank=2 gather_list is not None=True
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    dst=self.rank,
+                    group=subgroup,
+                    async_op=False,
+                )
             for src in range(len(gather_list)):
                 expected = (torch.ones_like(input) * self.rank) + src
                 self.assertEqual(gather_list[src], expected)
         else:
-            torch.distributed.gather(
-                input,
-                gather_list=None,
-                dst=self.rank - 1,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    dst=self.rank - 1,
+                    group=subgroup,
+                    async_op=False,
+                )

     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_gather_object_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_object_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3797,15 +3818,25 @@ def test_gather_object_subgroup(self):
             # another weird thing- what's the point of making me specify some empty objects in my list?
             # empty list should be valid imo.  (but it throws an error)
             gather_list = [{}, {}]
-            torch.distributed.gather_object(
-                input, object_gather_list=gather_list, dst=self.rank, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, dst=self.rank, group=subgroup
+                )
             for src in range(len(gather_list)):
                 self.assertEqual(gather_list[src]["rank"], self.rank + src)
         else:
-            torch.distributed.gather_object(
-                input, object_gather_list=None, dst=self.rank - 1, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, dst=self.rank - 1, group=subgroup
+                )

     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -3931,7 +3962,8 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):

     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_scatter_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3940,18 +3972,27 @@ def test_scatter_subgroup(self):
         x = torch.empty((10,), device=device)
         expected = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
-            c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=None, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
         else:
             scatter_list = [
                 torch.ones((10,), device=device) * (self.rank - 1),
                 torch.ones((10,), device=device) * self.rank,
             ]
-            c10d.scatter(x, scatter_list=scatter_list, src=self.rank, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=scatter_list, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(
+                    x, scatter_list=scatter_list, src=self.rank, group=subgroup
+                )
         self.assertEqual(x, expected)

     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_scatter_object_list_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_object_list_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3960,24 +4001,40 @@ def test_scatter_object_list_subgroup(self):
         scatter_object_output_list = [None]
         expected = [{"rank": self.rank}]
         if self.rank == 0 or self.rank == 2:
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=None,
-                src=self.rank + 1,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    src=self.rank + 1,
+                    group=subgroup,
+                )

         else:
             scatter_object_input_list = [
                 {"rank": self.rank - 1},
                 {"rank": self.rank},
             ]
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=scatter_object_input_list,
-                src=self.rank,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    src=self.rank,
+                    group=subgroup,
+                )
         self.assertEqual(scatter_object_output_list, expected)

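Note on what the new "group_rank" parametrization exercises: with group_dst/group_src the destination or source is addressed by its rank inside the subgroup passed via group=, while the pre-existing dst/src arguments expect a global rank. Below is a minimal sketch of that distinction, not taken from this PR: the helper name gather_into_subgroup and the launch setup are illustrative assumptions, group_dst is the keyword these tests exercise, and dist.get_rank/dist.get_group_rank/dist.get_global_rank are existing torch.distributed APIs.

import torch
import torch.distributed as dist


def gather_into_subgroup(tensor: torch.Tensor, subgroup, use_group_rank: bool):
    # Illustrative helper (not from this PR). Rank 0 of the subgroup collects;
    # the other subgroup ranks only contribute their tensor.
    is_dst = dist.get_group_rank(subgroup, dist.get_rank()) == 0
    gather_list = (
        [torch.empty_like(tensor) for _ in range(subgroup.size())] if is_dst else None
    )
    if use_group_rank:
        # group_dst is interpreted relative to `subgroup`: 0 means the subgroup's first rank.
        dist.gather(tensor, gather_list=gather_list, group_dst=0, group=subgroup)
    else:
        # dst is a global rank, so the subgroup-local index must be translated first.
        dst = dist.get_global_rank(subgroup, 0)
        dist.gather(tensor, gather_list=gather_list, dst=dst, group=subgroup)
    return gather_list

Both branches should land the same data on the same process; the parametrized tests above assert exactly that equivalence for gather, gather_object, scatter, and scatter_object_list.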