diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 0e3b1b77c5f..67da97d2c24 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -56,7 +56,7 @@
 from verl.utils.debug import marked_timer
 from verl.utils.metric import reduce_metrics
 from verl.utils.rollout_skip import RolloutSkip
-from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
+from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
 
@@ -900,15 +900,35 @@ def _stop_profiling(self, do_profile: bool) -> None:
         if self.use_rm:
             self.rm_wg.stop_profile()
 
-    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
-        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
+    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
+        """Reorder the data on the single controller so that each dp rank gets a similar total workload."""
         attention_mask = batch.batch["attention_mask"]
         batch_size = attention_mask.shape[0]
-        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
+        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1)  # (train_batch_size,)
+        global_seqlen_lst = calculate_workload(global_seqlen_lst)
         world_size = self.actor_rollout_wg.world_size
-        global_partition_lst = get_seqlen_balanced_partitions(
-            global_seqlen_lst, k_partitions=world_size, equal_size=True
-        )
+        if keep_minibatch:
+            # Decouple the DP balancing and mini-batching: balance each mini-batch across ranks separately.
+            minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
+            minibatch_num = len(global_seqlen_lst) // minibatch_size
+            global_partition_lst = [[] for _ in range(world_size)]
+            for i in range(minibatch_num):
+                rearrange_minibatch_lst = get_seqlen_balanced_partitions(
+                    global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
+                    k_partitions=world_size,
+                    equal_size=True,
+                )
+                for j, part in enumerate(rearrange_minibatch_lst):
+                    global_partition_lst[j].extend([x + minibatch_size * i for x in part])
+        else:
+            global_partition_lst = get_seqlen_balanced_partitions(
+                global_seqlen_lst, k_partitions=world_size, equal_size=True
+            )
+        # Place smaller micro-batches at both ends to reduce bubbles in pipeline parallelism.
+        for idx, partition in enumerate(global_partition_lst):
+            partition.sort(key=lambda x: (global_seqlen_lst[x], x))
+            ordered_partition = partition[::2] + partition[1::2][::-1]
+            global_partition_lst[idx] = ordered_partition
         # reorder based on index. The data will be automatically equally partitioned by dispatch function
         global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
         batch.reorder(global_idx)
@@ -1038,7 +1058,6 @@ def fit(self):
                 # NOTE: This usually changes the order of data in the `batch`,
                 # which won't affect the advantage calculation (since it's based on uid),
                 # but might affect the loss calculation (due to the change of mini-batching).
-                # TODO: Decouple the DP balancing and mini-batching.
                 if self.config.trainer.balance_batch:
                     self._balance_batch(batch, metrics=metrics)
 
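The tail of the new `_balance_batch` is the interesting part: within each dp-rank partition, indices are sorted by workload in ascending order and then interleaved (`partition[::2] + partition[1::2][::-1]`) so the heaviest sequences land in the middle of the rank's micro-batch schedule and the lightest at both ends, reducing pipeline-parallel warm-up and cool-down bubbles. A minimal standalone sketch with hypothetical sequence lengths (not part of the patch):

```python
# Zig-zag placement as in _balance_batch, on made-up sequence lengths.
seqlens = [512, 4096, 1024, 2048, 256, 8192]
partition = list(range(len(seqlens)))          # indices assigned to one dp rank
partition.sort(key=lambda x: (seqlens[x], x))  # ascending -> [4, 0, 2, 3, 1, 5]
ordered = partition[::2] + partition[1::2][::-1]
print([seqlens[i] for i in ordered])           # [256, 1024, 4096, 8192, 2048, 512]
```

With `keep_minibatch=True`, the same partitioning runs independently on each `ppo_mini_batch_size` slice, so every mini-batch keeps its original membership and only its distribution across dp ranks changes.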
diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py
index 5354d5114e6..bc5588f7ac6 100644
--- a/verl/utils/seqlen_balancing.py
+++ b/verl/utils/seqlen_balancing.py
@@ -24,6 +24,16 @@
 from verl.utils.device import get_device_name
 
 
+def calculate_workload(seqlen_list: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate the workload of a dense transformer block from the sequence length.
+    FLOPs = 12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2.
+    Hardcodes the constants for a 7B model (hidden_size=4096),
+    so the FLOPs are proportional to (6 * 4096 * seqlen + seqlen^2).
+    """
+    return 24576 * seqlen_list + seqlen_list**2
+
+
 def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
     class Set:
@@ -298,20 +308,22 @@ def rearrange_micro_batches(
     if num_batches_divided_by is not None:
         num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by)
 
-    seq_len_effective = seq_len_effective.tolist()
     assert num_micro_batches <= len(seq_len_effective)
-    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
+    workloads = calculate_workload(seq_len_effective)
+    micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)
 
     if use_dynamic_bsz_balance:
-        # Use the sum of squared sequence lengths to approximate attention computation workload
+        # Use each partition's total estimated workload to order the micro-batches
         micro_bsz_idx.sort(
             key=lambda partition: (
-                sum(seq_len_effective[idx] ** 2 for idx in partition),
-                min(partition) if partition else 0,
+                sum(workloads[idx] for idx in partition),
+                partition[0] if partition else 0,
             ),
             reverse=True,
         )
+        # Place smaller micro-batches at both ends to reduce bubbles exposed during warm-up and cool-down.
+        micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2]
 
     micro_batches = []
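The constant in `calculate_workload` follows from the docstring's FLOPs model: dividing `12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2` by the common factor `2 * hidden_size` gives `6 * 4096 * seqlen + seqlen^2 = 24576 * seqlen + seqlen^2` for the hardcoded `hidden_size=4096`. A quick standalone check of the arithmetic (assumes only `torch`):

```python
import torch

# Same formula as calculate_workload: linear GEMM term plus quadratic attention term.
def workload(seqlen: torch.Tensor) -> torch.Tensor:
    return 24576 * seqlen + seqlen**2

print(workload(torch.tensor([1024, 8192])))  # tensor([ 26214400, 268435456])
```

The two terms cross at seqlen = 24576, so for typical rollout lengths the dense GEMM term dominates the estimate and the quadratic attention term only takes over for very long sequences.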
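The final interleave in `rearrange_micro_batches` applies the same light-at-the-ends idea across micro-batches: after the descending sort by total workload, `[::2][::-1] + [1::2]` moves the heaviest micro-batches to the middle of the schedule. Illustrated on hypothetical per-micro-batch workload sums:

```python
# Workload totals already sorted in descending order, as after the sort() above.
totals = [90, 70, 50, 30, 20, 10]
print(totals[::2][::-1] + totals[1::2])  # [20, 50, 90, 70, 30, 10]
```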