diff --git a/paddle/phi/core/distributed/check/static_check.cc b/paddle/phi/core/distributed/check/static_check.cc
index b6f992a0592912..637386668cf32b 100644
--- a/paddle/phi/core/distributed/check/static_check.cc
+++ b/paddle/phi/core/distributed/check/static_check.cc
@@ -183,14 +183,33 @@ void CommStaticCheck::GatherLikeShape(const DenseTensor& out_tensor,
                                       int cur_rank,
                                       int world_size,
                                       phi::AllocationType place) {
-  CheckShape(out_tensor,
-             in_tensor,
-             dst_rank,
-             cur_rank,
-             world_size,
-             /*out_size_factor*/ 1,
-             /*in_size_factor*/ world_size,
-             place);
+  CheckRank(dst_rank, world_size);
+  CheckRank(cur_rank, world_size);
+
+  CheckPlace(out_tensor, in_tensor, place);
+  CheckDataType(out_tensor, in_tensor);
+
+  CheckGatherShape(out_tensor);
+  CheckGatherShape(in_tensor);
+  int64_t out_size = out_tensor.numel(), in_size = in_tensor.numel();
+  PADDLE_ENFORCE_EQ(
+      out_size,
+      in_size * world_size,
+      common::errors::InvalidArgument(
+          "Output tensor size should equal input tensor size times "
+          "world_size in gather-like communication. "
+          "out_size=%ld, in_size=%ld, world_size=%d",
+          out_size,
+          in_size,
+          world_size));
+}
+
+void CommStaticCheck::CheckGatherShape(const phi::DenseTensor& tensor) {
+  PADDLE_ENFORCE_GE(
+      tensor.numel(),
+      0,
+      common::errors::InvalidArgument(
+          "Tensor size must be non-negative in gather-like communication."));
 }
 
 }  // namespace phi::distributed
diff --git a/paddle/phi/core/distributed/check/static_check.h b/paddle/phi/core/distributed/check/static_check.h
index e3c2e1ee7fb68e..476216c0c0375a 100644
--- a/paddle/phi/core/distributed/check/static_check.h
+++ b/paddle/phi/core/distributed/check/static_check.h
@@ -81,6 +81,8 @@ struct CommStaticCheck {
                               int cur_rank,
                               int world_size,
                               phi::AllocationType place = phi::AllocationType::GPU);
+
+  static void CheckGatherShape(const phi::DenseTensor& tensor);
 };
 
 }  // namespace distributed
diff --git a/python/paddle/distributed/communication/all_gather.py b/python/paddle/distributed/communication/all_gather.py
index 407f8f3f624234..deda5303c90e99 100644
--- a/python/paddle/distributed/communication/all_gather.py
+++ b/python/paddle/distributed/communication/all_gather.py
@@ -85,7 +85,7 @@ def all_gather(
 
 
 def all_gather_object(
-    object_list: list[_T], obj: _T, group: Group = None
+    object_list: list[_T] | list[None], obj: _T, group: Group = None
 ) -> None:
     """
 
@@ -110,7 +110,7 @@ def all_gather_object(
 
                 >>> import paddle.distributed as dist
                 >>> dist.init_parallel_env()
-                >>> object_list = []  # type: ignore
+                >>> object_list = [None for _ in range(dist.get_world_size())]
                 >>> if dist.get_rank() == 0:
                 ...     obj = {"foo": [1, 2, 3]}
                 >>> else:
@@ -139,7 +139,9 @@ def all_gather_object(
     tensor_list = []
     all_gather(tensor_list, input_tensor, group)
 
+    # Ensure object_list has enough slots for all gathered objects.
+    while len(object_list) < len(tensor_list):
+        object_list.append(None)
+
     for i, tensor in enumerate(tensor_list):
-        object_list.append(
-            convert_tensor_to_object(tensor, list_len_of_tensor[i])
-        )
+        object_list[i] = convert_tensor_to_object(tensor, list_len_of_tensor[i])
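
Usage sketch for the patched all_gather_object (illustrative only, not part of
the patch; assumes a two-process run via a launcher such as
`python -m paddle.distributed.launch --nproc_per_node=2 demo.py`, where
`demo.py` is a hypothetical script name):

    import paddle.distributed as dist

    dist.init_parallel_env()

    # With this change, object_list may be pre-sized with None placeholders,
    # as the updated docstring shows; a shorter (or empty) list also works,
    # since the implementation now pads it with None before filling slot i
    # with the object contributed by rank i.
    object_list = [None for _ in range(dist.get_world_size())]
    obj = {"rank": dist.get_rank()}
    dist.all_gather_object(object_list, obj)
    print(object_list)  # every rank prints [{'rank': 0}, {'rank': 1}]

Pre-sizing the list with None placeholders mirrors the torch.distributed
calling convention; the in-place assignment keeps the caller's list object
identity intact instead of silently appending past existing entries.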