diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 7e9e6981f3c..a12a4ec5c9a 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -57,7 +57,7 @@
 from verl.utils.debug import marked_timer
 from verl.utils.metric import reduce_metrics
 from verl.utils.rollout_skip import RolloutSkip
-from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
+from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
 
@@ -914,35 +914,15 @@ def _stop_profiling(self, do_profile: bool) -> None:
         if self.use_rm:
             self.rm_wg.stop_profile()
 
-    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
+    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
         """Reorder the data on single controller such that each dp rank gets similar total tokens"""
         attention_mask = batch.batch["attention_mask"]
         batch_size = attention_mask.shape[0]
-        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1)  # (train_batch_size,)
-        global_seqlen_lst = calculate_workload(global_seqlen_lst)
+        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
         world_size = self.actor_rollout_wg.world_size
-        if keep_minibatch:
-            # Decouple the DP balancing and mini-batching.
-            minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
-            minibatch_num = len(global_seqlen_lst) // minibatch_size
-            global_partition_lst = [[] for _ in range(world_size)]
-            for i in range(minibatch_num):
-                rearrange_minibatch_lst = get_seqlen_balanced_partitions(
-                    global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
-                    k_partitions=world_size,
-                    equal_size=True,
-                )
-                for j, part in enumerate(rearrange_minibatch_lst):
-                    global_partition_lst[j].extend([x + minibatch_size * i for x in part])
-        else:
-            global_partition_lst = get_seqlen_balanced_partitions(
-                global_seqlen_lst, k_partitions=world_size, equal_size=True
-            )
-        # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
-        for idx, partition in enumerate(global_partition_lst):
-            partition.sort(key=lambda x: (global_seqlen_lst[x], x))
-            ordered_partition = partition[::2] + partition[1::2][::-1]
-            global_partition_lst[idx] = ordered_partition
+        global_partition_lst = get_seqlen_balanced_partitions(
+            global_seqlen_lst, k_partitions=world_size, equal_size=True
+        )
         # reorder based on index. The data will be automatically equally partitioned by dispatch function
         global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
         batch.reorder(global_idx)
@@ -1123,6 +1103,7 @@ def fit(self):
                 # NOTE: This usually changes the order of data in the `batch`,
                 # which won't affect the advantage calculation (since it's based on uid),
                 # but might affect the loss calculation (due to the change of mini-batching).
+                # TODO: Decouple the DP balancing and mini-batching.
                 if self.config.trainer.balance_batch:
                     self._balance_batch(batch, metrics=metrics)
 
diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py
index bc5588f7ac6..5354d5114e6 100644
--- a/verl/utils/seqlen_balancing.py
+++ b/verl/utils/seqlen_balancing.py
@@ -24,16 +24,6 @@
 from verl.utils.device import get_device_name
 
 
-def calculate_workload(seqlen_list: list[int]):
-    """
-    Calculate the workload for a dense transformer block based on sequence length.
-    FLOPs = 12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2
-    Hardcodes the constants by a 7B model (hidden_size=4096),
-    so the FLOPs are propotional to (6 * 4096 * seqlen + seqlen^2).
-    """
-    return 24576 * seqlen_list + seqlen_list**2
-
-
 def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
     class Set:
@@ -308,22 +298,20 @@ def rearrange_micro_batches(
     if num_batches_divided_by is not None:
         num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by)
 
+    seq_len_effective = seq_len_effective.tolist()
     assert num_micro_batches <= len(seq_len_effective)
-    workloads = calculate_workload(seq_len_effective)
-    micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)
+    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
 
     if use_dynamic_bsz_balance:
         # Use the sum of squared sequence lengths to approximate attention computation workload
         micro_bsz_idx.sort(
             key=lambda partition: (
-                sum(workloads[idx] for idx in partition),
-                partition[0] if partition else 0,
+                sum(seq_len_effective[idx] ** 2 for idx in partition),
+                min(partition) if partition else 0,
             ),
             reverse=True,
         )
-        # Place smaller micro-batches at both ends to reduce the bubbles exposed during the warm-up and cool-down.
-        micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2]
 
     micro_batches = []
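For reference, below is a self-contained sketch of the two heuristics this patch removes: the hardcoded FLOPs model (`calculate_workload`) and the zig-zag placement that put the cheapest micro-batches at both ends of the schedule. The `zigzag` helper name is hypothetical; in verl this logic was inlined in `_balance_batch` and `rearrange_micro_batches`.

    # Standalone sketch of the removed heuristics (hypothetical helper names,
    # not part of verl's API).

    def calculate_workload(seqlen: int) -> int:
        # Removed FLOPs model for one dense transformer block:
        # 12 * h^2 * s + 2 * h * s^2 with h = 4096 hardcoded, hence
        # proportional to 6 * 4096 * s + s^2 = 24576 * s + s^2.
        return 24576 * seqlen + seqlen**2

    def zigzag(partition: list[int], cost: list[int]) -> list[int]:
        # Removed ordering: sort indices by workload, then interleave so the
        # cheapest micro-batches land at both ends of the schedule, reducing
        # pipeline-parallel bubbles during warm-up and cool-down.
        ordered = sorted(partition, key=lambda i: (cost[i], i))
        return ordered[::2] + ordered[1::2][::-1]

    seqlens = [512, 8192, 1024, 4096]
    costs = [calculate_workload(s) for s in seqlens]
    print(zigzag(list(range(len(seqlens))), costs))  # -> [0, 3, 1, 2]

With the patch applied, DP balancing falls back to plain per-sample token counts and the partition order returned by `get_seqlen_balanced_partitions`, while micro-batch balancing keeps only the squared-seqlen sort under `use_dynamic_bsz_balance`.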