diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index 781a0e4d045bc7..de35a3e7e15506 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -899,7 +899,7 @@ def _build_comm_buffers(
             logger.info(f"Tensor Fusion Color {g_color} and Group {g_group}: ")
             var_groups = assign_group_by_size(params, group_size)
             opt_states_sizes = get_group_size(params, group_size)
-            for _, parameters in var_groups.items():
+            for idx, parameters in var_groups.items():
                 buffer = FusedCommBuffer(
                     group_idx,
                     parameters,
@@ -918,7 +918,7 @@ def _build_comm_buffers(
                         self._slice_params[param.name].is_offload_opt = True
                     # here group_size is parameter size (GB)
                     # optimizer states(float32) size is 6 times as much as parameter(bfloat16) size
-                    offload_buffer_size -= sum(opt_states_sizes)
+                    offload_buffer_size -= opt_states_sizes[idx]
                 else:
                     for param in parameters:
                         self._slice_params[param.name].is_offload_opt = False
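
Why the fix matters: the old code subtracted the combined optimizer-state size of
all fusion groups from the offload budget on every loop iteration, so the budget
collapsed after the first group and later groups were never offloaded; the fix
subtracts only the current group's share. Below is a minimal standalone sketch of
the before/after behavior. It assumes opt_states_sizes holds one per-group size in
GB and that a group is offloaded while the remaining budget is positive; the
variable names mirror the diff, but the loop, the budget check, and the numbers
are illustrative, not PaddlePaddle API.

    # Hypothetical model of the budget accounting in _build_comm_buffers.
    # Assumption: opt_states_sizes[i] is the optimizer-state footprint (GB)
    # of fusion group i, and a group is offloaded while budget remains.
    opt_states_sizes = [2.0, 3.0, 1.0]  # three fused groups
    budget = 4.0                        # offload_buffer_size analogue

    offloaded_old, remaining_old = [], budget
    for idx, _ in enumerate(opt_states_sizes):
        if remaining_old > 0:
            offloaded_old.append(idx)
            remaining_old -= sum(opt_states_sizes)  # bug: subtracts all 6.0 GB each time

    offloaded_new, remaining_new = [], budget
    for idx, _ in enumerate(opt_states_sizes):
        if remaining_new > 0:
            offloaded_new.append(idx)
            remaining_new -= opt_states_sizes[idx]  # fix: subtract this group's share only

    print(offloaded_old, remaining_old)  # [0] -2.0   -> budget exhausted after one group
    print(offloaded_new, remaining_new)  # [0, 1] -1.0 -> groups offload until budget is spent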