33 changes: 26 additions & 7 deletions verl/trainer/ppo/ray_trainer.py
@@ -56,7 +56,7 @@
 from verl.utils.debug import marked_timer
 from verl.utils.metric import reduce_metrics
 from verl.utils.rollout_skip import RolloutSkip
-from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
+from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
 
@@ -900,15 +900,35 @@ def _stop_profiling(self, do_profile: bool) -> None:
         if self.use_rm:
             self.rm_wg.stop_profile()
 
-    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
+    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
         """Reorder the data on single controller such that each dp rank gets similar total tokens"""
         attention_mask = batch.batch["attention_mask"]
         batch_size = attention_mask.shape[0]
-        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
+        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1)  # (train_batch_size,)
+        global_seqlen_lst = calculate_workload(global_seqlen_lst)
         world_size = self.actor_rollout_wg.world_size
-        global_partition_lst = get_seqlen_balanced_partitions(
-            global_seqlen_lst, k_partitions=world_size, equal_size=True
-        )
+        if keep_minibatch:
+            # Decouple the DP balancing and mini-batching.
+            minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
+            minibatch_num = len(global_seqlen_lst) // minibatch_size
+            global_partition_lst = [[] for _ in range(world_size)]
+            for i in range(minibatch_num):
+                rearrange_minibatch_lst = get_seqlen_balanced_partitions(
+                    global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
+                    k_partitions=world_size,
+                    equal_size=True,
+                )
+                for j, part in enumerate(rearrange_minibatch_lst):
+                    global_partition_lst[j].extend([x + minibatch_size * i for x in part])
+        else:
+            global_partition_lst = get_seqlen_balanced_partitions(
+                global_seqlen_lst, k_partitions=world_size, equal_size=True
+            )
+        # Place smaller micro-batches at both ends to reduce bubbles in pipeline parallelism.
+        for idx, partition in enumerate(global_partition_lst):
+            partition.sort(key=lambda x: (global_seqlen_lst[x], x))
+            ordered_partition = partition[::2] + partition[1::2][::-1]
+            global_partition_lst[idx] = ordered_partition
         # reorder based on index. The data will be automatically equally partitioned by dispatch function
         global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
         batch.reorder(global_idx)
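With `keep_minibatch=True`, each mini-batch is partitioned across ranks independently, so rebalancing never moves a sample across a mini-batch boundary and the loss sees the same mini-batch membership as before. A minimal standalone sketch of that index bookkeeping follows; `greedy_partitions` is a hypothetical stand-in for verl's `get_seqlen_balanced_partitions`, and the workloads and sizes are made up:

```python
# Standalone sketch of the keep_minibatch bookkeeping above.
# greedy_partitions is a stand-in for verl's get_seqlen_balanced_partitions;
# workloads, minibatch_size, and world_size are made up for illustration.
import heapq

def greedy_partitions(workloads, k_partitions):
    """Greedy longest-processing-time split of indices into k balanced groups."""
    heap = [(0, r) for r in range(k_partitions)]  # (running total, partition id)
    heapq.heapify(heap)
    parts = [[] for _ in range(k_partitions)]
    for idx in sorted(range(len(workloads)), key=lambda i: -workloads[i]):
        total, r = heapq.heappop(heap)
        parts[r].append(idx)
        heapq.heappush(heap, (total + workloads[idx], r))
    return parts

workloads = [9, 1, 5, 5, 8, 2, 7, 3]  # two mini-batches of 4, world_size = 2
minibatch_size, world_size = 4, 2
global_partition_lst = [[] for _ in range(world_size)]
for i in range(len(workloads) // minibatch_size):
    chunk = workloads[i * minibatch_size : (i + 1) * minibatch_size]
    for j, part in enumerate(greedy_partitions(chunk, world_size)):
        # Shift local indices back to global positions, as _balance_batch does.
        global_partition_lst[j].extend(x + minibatch_size * i for x in part)

print(global_partition_lst)  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```

Each rank ends up with two samples from every mini-batch (and, here, equal total workload per mini-batch). Unlike verl's partitioner, this greedy stand-in does not enforce the `equal_size` guarantee; it only illustrates the per-mini-batch index shifting.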
@@ -1038,7 +1058,6 @@ def fit(self):
                 # NOTE: This usually changes the order of data in the `batch`,
                 # which won't affect the advantage calculation (since it's based on uid),
                 # but might affect the loss calculation (due to the change of mini-batching).
-                # TODO: Decouple the DP balancing and mini-batching.
                 if self.config.trainer.balance_batch:
                     self._balance_batch(batch, metrics=metrics)
 
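Note that this hunk only deletes the now-resolved TODO: as of this diff, `fit()` still calls `_balance_batch` with the default `keep_minibatch=False`. A hypothetical opt-in at this call site (not part of the diff) would look like:

```python
# Hypothetical opt-in; the diff itself keeps the default keep_minibatch=False.
if self.config.trainer.balance_batch:
    self._balance_batch(batch, metrics=metrics, keep_minibatch=True)
```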
20 changes: 16 additions & 4 deletions verl/utils/seqlen_balancing.py
@@ -24,6 +24,16 @@
 from verl.utils.device import get_device_name
 
 
+def calculate_workload(seqlen_list: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate the workload for a dense transformer block based on sequence length.
+    FLOPs = 12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2
+    Hardcodes the constants for a 7B model (hidden_size=4096),
+    so the FLOPs are proportional to (6 * 4096 * seqlen + seqlen^2).
+    """
+    return 24576 * seqlen_list + seqlen_list**2
+
+
 def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
     class Set:
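The constant falls out of dropping the shared factor 2 * hidden_size: 12h²s + 2hs² = 2h(6hs + s²), and 6 * 4096 = 24576. A quick sketch checking that proportionality on a tensor of sequence lengths (`torch` assumed, matching how both call sites now pass tensors rather than lists):

```python
# Sketch: where 24576 comes from (6 * 4096) and the proportionality claim.
import torch

hidden_size = 4096
seqlens = torch.tensor([512, 1024, 4096])

# Full estimate: FLOPs = 12*h^2*s + 2*h*s^2 = 2h * (6*h*s + s^2).
flops = 12 * hidden_size**2 * seqlens + 2 * hidden_size * seqlens**2
workload = 6 * hidden_size * seqlens + seqlens**2  # what calculate_workload returns

# Identical up to the constant factor 2h, so both induce the same partitioning.
assert torch.equal(flops, 2 * hidden_size * workload)
```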
@@ -298,20 +308,22 @@ def rearrange_micro_batches(
     if num_batches_divided_by is not None:
         num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by)
 
-    seq_len_effective = seq_len_effective.tolist()
     assert num_micro_batches <= len(seq_len_effective)
 
-    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
+    workloads = calculate_workload(seq_len_effective)
+    micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)
 
     if use_dynamic_bsz_balance:
         # Use the sum of squared sequence lengths to approximate attention computation workload
         micro_bsz_idx.sort(
             key=lambda partition: (
-                sum(seq_len_effective[idx] ** 2 for idx in partition),
-                min(partition) if partition else 0,
+                sum(workloads[idx] for idx in partition),
+                partition[0] if partition else 0,
             ),
             reverse=True,
         )
+        # Place smaller micro-batches at both ends to reduce the bubbles exposed during the warm-up and cool-down.
+        micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2]
 
     micro_batches = []
 
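Both call sites finish with the same interleaving trick: after sorting by workload, the even-indexed items and the reversed odd-indexed items are concatenated so the heaviest work lands mid-schedule and the lightest at the ends, shrinking pipeline warm-up and cool-down bubbles. A toy demonstration of the two slicing idioms (values made up):

```python
# Toy demo of the two interleaving idioms used in this PR (made-up values).
parts_desc = [50, 40, 30, 20, 10]  # micro-batches sorted by descending workload
print(parts_desc[::2][::-1] + parts_desc[1::2])    # [10, 30, 50, 40, 20]

samples_asc = [1, 2, 3, 4, 5]      # per-rank samples sorted ascending
print(samples_asc[::2] + samples_asc[1::2][::-1])  # [1, 3, 5, 4, 2]
# In both cases the largest element lands mid-schedule and the smallest
# at the ends, reducing exposed bubbles in pipeline parallelism.
```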