From ff226e18f2d686dffe5cde842b78d7f9fe996693 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 19 Jul 2024 10:03:40 +0800
Subject: [PATCH] [Enhance]: Fix sequence parallel memory bottleneck in DPO & ORPO (#830)

* [WIP]: Fix sequence parallel memory bottleneck in DPO

* loss mask before split

* refactor orpo
---
 xtuner/model/dpo.py  | 115 ++++++++++++++++++------------------
 xtuner/model/orpo.py | 133 ++++++++++++++++++++++---------------------
 2 files changed, 126 insertions(+), 122 deletions(-)

diff --git a/xtuner/model/dpo.py b/xtuner/model/dpo.py
index b46ea1c50..9a7b97a19 100644
--- a/xtuner/model/dpo.py
+++ b/xtuner/model/dpo.py
@@ -62,77 +62,66 @@ def _gather_masked_logits(self, logits, labels, mask):
 
     def get_logps(
             self,
-            all_logits,  # bs, seqlen,vocab_size
-            all_ref_logits,  # bs, seqlen,vocab_size
-            labels,  # bs, seqlen
+            policy_logps,  # bs, seqlen
+            ref_logps,  # bs, seqlen
+            loss_mask,  # bs, seqlen
     ):
-        labels = labels[:, 1:].clone()
-        all_logits = all_logits[:, :-1, :]
-        all_ref_logits = all_ref_logits[:, :-1, :]
-
-        labels[labels == -100] = 0
-        loss_mask = labels != 0
-        all_logps = self._gather_masked_logits(all_logits, labels,
-                                               loss_mask).sum(-1)
-        all_ref_logps = self._gather_masked_logits(all_ref_logits, labels,
-                                                   loss_mask).sum(-1)
+        policy_logps = policy_logps[:, :-1].sum(-1)
+        ref_logps = ref_logps[:, :-1].sum(-1)
+        loss_mask = loss_mask[:, :-1]
 
         if self.loss_type == 'ipo':  # average_log_prob
-            all_logps = all_logps / loss_mask.sum(-1)
-            all_ref_logps = all_ref_logps / loss_mask.sum(-1)
+            policy_logps = policy_logps / loss_mask.sum(-1)
+            ref_logps = ref_logps / loss_mask.sum(-1)
 
-        policy_chosen_logps = all_logps[::2]
-        policy_rejected_logps = all_logps[1::2]
-        reference_chosen_logps = all_ref_logps[::2]
-        reference_rejected_logps = all_ref_logps[1::2]
+        policy_chosen_logps = policy_logps[::2]
+        policy_rejected_logps = policy_logps[1::2]
+        reference_chosen_logps = ref_logps[::2]
+        reference_rejected_logps = ref_logps[1::2]
 
         return (policy_chosen_logps, policy_rejected_logps,
                 reference_chosen_logps, reference_rejected_logps)
 
-    def get_var_len_atten_logps(self, all_logits, all_ref_logits, labels,
+    def get_var_len_atten_logps(self, policy_logps, ref_logps, loss_mask,
                                 cu_seqlens, attention_mask):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         # unpack sequence
-        unpacked_logits = torch.split(all_logits, seqlens, dim=1)
-        unpacked_ref_logits = torch.split(all_ref_logits, seqlens, dim=1)
-        unpacked_labels = torch.split(labels, seqlens, dim=1)
+        unpacked_policy_logps = torch.split(policy_logps, seqlens, dim=1)
+        unpacked_ref_logps = torch.split(ref_logps, seqlens, dim=1)
+        unpacked_loss_mask = torch.split(loss_mask, seqlens, dim=1)
         if attention_mask is not None:
             # It indicates that we pad the original sequence, labels,
             # position_ids and cumulative_len for sequence parallel if the
             # attention_mask is not None.
             # We then need to remove the padded segments.
             assert False in attention_mask
-            unpacked_logits = unpacked_logits[:-1]
-            unpacked_ref_logits = unpacked_ref_logits[:-1]
-            unpacked_labels = unpacked_labels[:-1]
-            assert len(unpacked_logits) % 2 == 0
+            unpacked_policy_logps = unpacked_policy_logps[:-1]
+            unpacked_ref_logps = unpacked_ref_logps[:-1]
+            unpacked_loss_mask = unpacked_loss_mask[:-1]
+            assert len(unpacked_policy_logps) % 2 == 0
 
-        def compute_logps(_logits, _labels):
-            _labels = _labels[:, 1:].clone()
-            _logits = _logits[:, :-1, :]
-            _labels[_labels == -100] = 0
-            loss_mask = _labels != 0
-            logps = self._gather_masked_logits(_logits, _labels, loss_mask)
-            logps = logps.sum(-1)
+        def compute_logps(_logps, _mask):
+            _logps = _logps[:, :-1].sum(-1)
+            _mask = _mask[:, :-1]
             if self.loss_type == 'ipo':
-                logps /= loss_mask.sum(-1)
-            return logps
+                _logps /= _mask.sum(-1)
+            return _logps
 
         (policy_chosen_logps, policy_rejected_logps, reference_chosen_logps,
          reference_rejected_logps) = [], [], [], []
-        for i in range(len(unpacked_logits) // 2):
-            chosen = unpacked_logits[2 * i]
-            rejected = unpacked_logits[2 * i + 1]
-            chosen_ref = unpacked_ref_logits[2 * i]
-            rejected_ref = unpacked_ref_logits[2 * i + 1]
-            chosen_label = unpacked_labels[2 * i]
-            rejected_label = unpacked_labels[2 * i + 1]
-            policy_chosen_logps.append(compute_logps(chosen, chosen_label))
+        for i in range(len(unpacked_policy_logps) // 2):
+            chosen = unpacked_policy_logps[2 * i]
+            rejected = unpacked_policy_logps[2 * i + 1]
+            chosen_ref = unpacked_ref_logps[2 * i]
+            rejected_ref = unpacked_ref_logps[2 * i + 1]
+            chosen_mask = unpacked_loss_mask[2 * i]
+            rejected_mask = unpacked_loss_mask[2 * i + 1]
+            policy_chosen_logps.append(compute_logps(chosen, chosen_mask))
             policy_rejected_logps.append(
-                compute_logps(rejected, rejected_label))
+                compute_logps(rejected, rejected_mask))
             reference_chosen_logps.append(
-                compute_logps(chosen_ref, chosen_label))
+                compute_logps(chosen_ref, chosen_mask))
             reference_rejected_logps.append(
-                compute_logps(rejected_ref, rejected_label))
+                compute_logps(rejected_ref, rejected_mask))
 
         return (torch.stack(policy_chosen_logps),
                 torch.stack(policy_rejected_logps),
@@ -142,7 +131,7 @@ def compute_logps(_logits, _labels):
     @staticmethod
     def _split_for_sequence_parallel(data):
         # attention mask should not be split
-        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
+        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids', 'labels')
         sp_group = get_sequence_parallel_group()
         for key in ARGS_NEED_TO_SPLIT:
             val = data.get(key, None)
@@ -154,8 +143,14 @@ def _split_for_sequence_parallel(data):
 
     def compute_loss(self, data, data_samples=None):
         # modified from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py  # noqa
-
-        labels = data.pop('labels')
+        # shift labels first and add a dummy label at the end, to support sequence parallel  # noqa
+        data['labels'] = torch.cat(
+            (data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
+            dim=1)
+        tmp_label = data['labels'].clone()
+        tmp_label[tmp_label == 0] = -100
+        # loss mask of all tokens in all sp ranks
+        all_loss_mask = tmp_label != -100
 
         if get_sequence_parallel_world_size() > 1:
             data = self._split_for_sequence_parallel(data)
@@ -168,14 +163,22 @@ def compute_loss(self, data, data_samples=None):
         else:
             all_ref_logits = self.ref_llm(**data).logits
 
+        labels = data['labels']
+        labels[labels == -100] = 0
+        loss_mask = labels != 0  # loss mask in a single sp rank
+        policy_logps = self._gather_masked_logits(all_logits, labels,
+                                                  loss_mask)
+        ref_logps = self._gather_masked_logits(all_ref_logits, labels,
+                                               loss_mask)
+
         if get_sequence_parallel_world_size() > 1:
-            all_logits = gather_forward_split_backward(
-                all_logits,
+            policy_logps = gather_forward_split_backward(
+                policy_logps,
                 dim=1,
                 sp_group=get_sequence_parallel_group(),
                 grad_scale='up')
-            all_ref_logits = gather_forward_split_backward(
-                all_ref_logits,
+            ref_logps = gather_forward_split_backward(
+                ref_logps,
                 dim=1,
                 sp_group=get_sequence_parallel_group(),
                 grad_scale='up')
 
@@ -184,7 +187,7 @@ def compute_loss(self, data, data_samples=None):
             (policy_chosen_logps, policy_rejected_logps,
              reference_chosen_logps,
              reference_rejected_logps) = self.get_logps(
-                 all_logits, all_ref_logits, labels)
+                 policy_logps, ref_logps, all_loss_mask)
         else:
             message_hub = MessageHub.get_instance('varlen_attn_args')
             rank = dist.get_rank()
@@ -192,7 +195,7 @@ def compute_loss(self, data, data_samples=None):
             (policy_chosen_logps, policy_rejected_logps,
              reference_chosen_logps,
              reference_rejected_logps) = self.get_var_len_atten_logps(
-                 all_logits, all_ref_logits, labels, cu_seqlens,
+                 policy_logps, ref_logps, all_loss_mask, cu_seqlens,
                  data['attention_mask'])
 
         pi_logratios = policy_chosen_logps - policy_rejected_logps
diff --git a/xtuner/model/orpo.py b/xtuner/model/orpo.py
index 5fb4b7d27..37264088a 100644
--- a/xtuner/model/orpo.py
+++ b/xtuner/model/orpo.py
@@ -34,17 +34,12 @@ def _gather_masked_logits(self, logits, labels, mask):
 
     def get_logps(
             self,
-            all_logits,  # bs, seqlen,vocab_size
-            average_log_prob,  # bs, seqlen,vocab_size
-            labels,  # bs, seqlen
+            all_logps,  # bs, seqlen
+            average_log_prob,
+            loss_mask,  # bs, seqlen
     ):
-        labels = labels[:, 1:].clone()
-        all_logits = all_logits[:, :-1, :]
-
-        labels[labels == -100] = 0
-        loss_mask = labels != 0
-        all_logps = self._gather_masked_logits(all_logits, labels,
-                                               loss_mask).sum(-1)
+        all_logps = all_logps[:, :-1].sum(-1)
+        loss_mask = loss_mask[:, :-1]
 
         if average_log_prob:  # average_log_prob
             all_logps = all_logps / loss_mask.sum(-1)
@@ -53,47 +48,44 @@ def get_logps(
         rejected_logps = all_logps[1::2]
         return chosen_logps, rejected_logps
 
-    def get_var_len_atten_logps(self, all_logits, average_log_prob, labels,
+    def get_var_len_atten_logps(self, all_logps, average_log_prob, loss_mask,
                                 cu_seqlens, attention_mask):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         # unpack sequence
-        unpacked_logits = torch.split(all_logits, seqlens, dim=1)
-        unpacked_labels = torch.split(labels, seqlens, dim=1)
+        unpacked_logps = torch.split(all_logps, seqlens, dim=1)
+        unpacked_loss_mask = torch.split(loss_mask, seqlens, dim=1)
         if attention_mask is not None:
             # It indicates that we pad the original sequence, labels,
             # position_ids and cumulative_len for sequence parallel if the
             # attention_mask is not None.
             # We then need to remove the padded segments.
             assert False in attention_mask
-            unpacked_logits = unpacked_logits[:-1]
-            unpacked_labels = unpacked_labels[:-1]
-            assert len(unpacked_logits) % 2 == 0
-
-        def compute_logps(_logits, _labels):
-            _labels = _labels[:, 1:].clone()
-            _logits = _logits[:, :-1, :]
-            _labels[_labels == -100] = 0
-            loss_mask = _labels != 0
-            logps = self._gather_masked_logits(_logits, _labels, loss_mask)
-            logps = logps.sum(-1)
+            unpacked_logps = unpacked_logps[:-1]
+            unpacked_loss_mask = unpacked_loss_mask[:-1]
+            assert len(unpacked_logps) % 2 == 0
+
+        def compute_logps(_logps, _mask):
+            _logps = _logps[:, :-1].sum(-1)
+            _mask = _mask[:, :-1]
             if average_log_prob:
-                logps /= loss_mask.sum(-1)
-            return logps
+                _logps /= _mask.sum(-1)
+            return _logps
 
         chosen_logps, rejected_logps = [], []
-        for i in range(len(unpacked_logits) // 2):
-            chosen = unpacked_logits[2 * i]
-            rejected = unpacked_logits[2 * i + 1]
-            chosen_label = unpacked_labels[2 * i]
-            rejected_label = unpacked_labels[2 * i + 1]
-            chosen_logps.append(compute_logps(chosen, chosen_label))
-            rejected_logps.append(compute_logps(rejected, rejected_label))
+        for i in range(len(unpacked_logps) // 2):
+            chosen = unpacked_logps[2 * i]
+            rejected = unpacked_logps[2 * i + 1]
+            chosen_mask = unpacked_loss_mask[2 * i]
+            rejected_mask = unpacked_loss_mask[2 * i + 1]
+            chosen_logps.append(compute_logps(chosen, chosen_mask))
+            rejected_logps.append(compute_logps(rejected, rejected_mask))
 
         return (torch.stack(chosen_logps), torch.stack(rejected_logps))
 
     def cross_entropy_loss(self, logits, labels):
         logits = logits[..., :-1, :].contiguous()
-        labels = labels[..., 1:].contiguous()
+        # labels are already shifted; drop the trailing dummy label
+        labels = labels[..., :-1].contiguous()
         # Flatten the tokens
         loss_fct = nn.CrossEntropyLoss()
         logits = logits.view(-1, logits.shape[-1])
@@ -126,7 +118,8 @@ def odds_ratio_loss(
     @staticmethod
     def _split_for_sequence_parallel(data):
         # attention mask should not be split
-        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
+        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids', 'labels',
+                              'chosen_rejected_tag')
         sp_group = get_sequence_parallel_group()
         for key in ARGS_NEED_TO_SPLIT:
             val = data.get(key, None)
@@ -137,53 +130,61 @@ def _split_for_sequence_parallel(data):
         return data
 
     def compute_loss(self, data, data_samples=None):
-        labels_ori = data.pop('labels')
+        # shift labels first and add a dummy label at the end, to support sequence parallel  # noqa
+        data['labels'] = torch.cat(
+            (data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
+            dim=1)
+        tmp_label = data['labels'].clone()
+        tmp_label[tmp_label == 0] = -100
+        # loss mask of all tokens in all sp ranks
+        all_loss_mask = tmp_label != -100
+
+        if self.use_varlen_attn:
+            # create a chosen rejected tag for varlen_attn ce loss
+            message_hub = MessageHub.get_instance('varlen_attn_args')
+            rank = dist.get_rank()
+            cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+
+            chosen_rejected_tag = torch.ones_like(data['labels'])
+            unpacked_tag = list(
+                torch.split(chosen_rejected_tag, seqlens, dim=1))
+            for i in range(len(unpacked_tag) // 2):
+                unpacked_tag[2 * i + 1] *= 0
+            chosen_rejected_tag = torch.cat(unpacked_tag, dim=1)
+            data['chosen_rejected_tag'] = chosen_rejected_tag
 
         if get_sequence_parallel_world_size() > 1:
             data = self._split_for_sequence_parallel(data)
-
+        chosen_rejected_tag = data.pop('chosen_rejected_tag', None)
         all_logits = self.llm(**data).logits
+
+        labels = data['labels'].clone()
+        labels[labels == -100] = 0
+        loss_mask = labels != 0  # loss mask in a single sp rank
+        all_logps = self._gather_masked_logits(all_logits, labels, loss_mask)
 
         if get_sequence_parallel_world_size() > 1:
-            all_logits = gather_forward_split_backward(
-                all_logits,
+            all_logps = gather_forward_split_backward(
+                all_logps,
                 dim=1,
                 sp_group=get_sequence_parallel_group(),
                 grad_scale='up')
 
         if not self.use_varlen_attn:
             chosen_nll_loss = self.cross_entropy_loss(all_logits[::2],
-                                                      labels_ori.clone()[::2])
+                                                      data['labels'][::2])
             chosen_logps, rejected_logps = self.get_logps(
-                all_logits, True, labels_ori)
+                all_logps, True, all_loss_mask)
         else:
-            message_hub = MessageHub.get_instance('varlen_attn_args')
-            rank = dist.get_rank()
-            cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-
-            attention_mask = data['attention_mask']
-            if attention_mask is not None:
-                # It indicate that we pad the original sequence, labels,
-                # position_ids and cumulative_len for sequence parallel if the
-                # attention_mask is not None.
-                # We then need to remove the padded segments.
-                logits = torch.split(all_logits, seqlens, dim=1)[:-1]
-                assert len(logits) % 2 == 0
-                chosen_logits = logits[::2]
-                labels = torch.split(labels_ori.clone(), seqlens, dim=1)[:-1]
-                assert len(labels) % 2 == 0
-                chosen_labels = labels[::2]
-            else:
-                chosen_logits = torch.split(all_logits, seqlens, dim=1)[::2]
-                chosen_labels = torch.split(
-                    labels_ori.clone(), seqlens, dim=1)[::2]
-
-            chosen_logits = torch.cat(chosen_logits, dim=1)
-            chosen_labels = torch.cat(chosen_labels, dim=1)
+            chosen_idxs = chosen_rejected_tag == 1
+            chosen_logits = all_logits[chosen_idxs]
+            chosen_labels = data['labels'][chosen_idxs]
             chosen_nll_loss = self.cross_entropy_loss(chosen_logits,
                                                       chosen_labels)
+
             chosen_logps, rejected_logps = self.get_var_len_atten_logps(
-                all_logits, True, labels_ori, cu_seqlens, attention_mask)
+                all_logps, True, all_loss_mask, cu_seqlens,
+                data['attention_mask'])
 
         (losses, chosen_rewards, rejected_rewards, log_odds_ratio,
          log_odds_chosen) = self.odds_ratio_loss(chosen_logps, rejected_logps)
         losses = losses.mean()
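
Why this fixes the memory bottleneck: before this patch, every sequence-parallel
rank ran gather_forward_split_backward on the full (bs, seqlen, vocab_size)
logits just to read off one log-probability per token afterwards. The patch
gathers each token's target log-prob on its own rank first, so only a
(bs, seqlen) tensor crosses the sp group. A rough standalone sketch of the
shape arithmetic, with toy sizes and a local gather_target_logps helper
standing in for xtuner's _gather_masked_logits (whose body is not shown in
this diff):

    import torch

    def gather_target_logps(logits, labels, mask):
        # Per-token log-prob of each target label:
        # (bs, seqlen, vocab) -> (bs, seqlen).
        logps = torch.gather(
            logits.log_softmax(-1), dim=2,
            index=labels.unsqueeze(2)).squeeze(2)
        return logps * mask

    bs, seqlen, vocab_size = 4, 16, 1000  # toy sizes; real vocabs are 32k+
    logits = torch.randn(bs, seqlen, vocab_size)
    labels = torch.randint(0, vocab_size, (bs, seqlen))
    labels[:, :4] = -100  # pretend the prompt tokens are masked out

    # Mirror the patch: turn -100 into a valid gather index, then mask it.
    index = labels.clone()
    index[index == -100] = 0
    loss_mask = index != 0

    per_token_logps = gather_target_logps(logits, index, loss_mask)
    # Only this (bs, seqlen) tensor now needs to be all-gathered across
    # sp ranks: a vocab_size-fold reduction over gathering the raw logits.
    print(per_token_logps.shape)  # torch.Size([4, 16])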
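
The label pre-shift is what lets 'labels' join ARGS_NEED_TO_SPLIT: once the
sequence is split, the target of a slice's last token would live on the next
rank, so the patch shifts labels left by one and appends a dummy before
splitting, making position t on every rank hold both the logits and the target
for token t. A small sanity check of that equivalence in plain torch (toy
sizes, not xtuner code):

    import torch

    # One packed sequence of 12 tokens, to be split across 2 sp ranks.
    labels = torch.arange(1, 13).reshape(1, 12)
    labels[0, :3] = -100  # masked prompt region

    # Classic shift applied after the fact: logits[:, :-1] vs labels[:, 1:].
    classic_targets = labels[:, 1:]

    # Patch's version: shift first, append one dummy label, then split.
    shifted = torch.cat(
        (labels[:, 1:], torch.zeros_like(labels[:, :1])), dim=1)
    rank0, rank1 = shifted.split(6, dim=1)

    # Re-concatenating the per-rank slices reproduces the classic targets.
    assert torch.equal(
        torch.cat([rank0, rank1], dim=1)[:, :-1], classic_targets)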