[Enhance]: Fix sequence parallel memory bottleneck in DPO & ORPO #830

Merged · 3 commits · Jul 19, 2024
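The bottleneck this PR targets: with sequence parallelism enabled, DPO and ORPO previously all-gathered the full policy and reference logits of shape (bs, seqlen, vocab_size) across sequence-parallel ranks before computing log-probabilities. The diffs below instead pick out the per-token label log-probabilities (shape (bs, seqlen)) on each rank with _gather_masked_logits and only all-gather that small tensor, so the gathered activation no longer scales with the vocabulary size. A rough, illustrative estimate of the difference (the shapes and dtype below are assumptions, not values from this PR):

# Rough, illustrative estimate (assumed shapes, not code from this PR) of the
# tensor that gather_forward_split_backward has to materialize on every rank:
# gathering logits scales with vocab_size, gathering per-token log-probs does not.
def gathered_bytes(bs, seqlen, vocab_size, bytes_per_elem=2):
    # before: policy and reference logits, (bs, seqlen, vocab_size) each
    logits_bytes = 2 * bs * seqlen * vocab_size * bytes_per_elem
    # after: policy and reference per-token log-probs, (bs, seqlen) each
    logps_bytes = 2 * bs * seqlen * bytes_per_elem
    return logits_bytes, logps_bytes

before, after = gathered_bytes(bs=1, seqlen=32768, vocab_size=92544)
print(f'{before / 2**30:.1f} GiB vs {after / 2**20:.2f} MiB per rank')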
115 changes: 59 additions & 56 deletions xtuner/model/dpo.py
@@ -62,77 +62,66 @@ def _gather_masked_logits(self, logits, labels, mask):

     def get_logps(
             self,
-            all_logits,  # bs, seqlen, vocab_size
-            all_ref_logits,  # bs, seqlen, vocab_size
-            labels,  # bs, seqlen
+            policy_logps,  # bs, seqlen
+            ref_logps,  # bs, seqlen
+            loss_mask,  # bs, seqlen
     ):
-        labels = labels[:, 1:].clone()
-        all_logits = all_logits[:, :-1, :]
-        all_ref_logits = all_ref_logits[:, :-1, :]
-
-        labels[labels == -100] = 0
-        loss_mask = labels != 0
-        all_logps = self._gather_masked_logits(all_logits, labels,
-                                               loss_mask).sum(-1)
-        all_ref_logps = self._gather_masked_logits(all_ref_logits, labels,
-                                                   loss_mask).sum(-1)
+        policy_logps = policy_logps[:, :-1].sum(-1)
+        ref_logps = ref_logps[:, :-1].sum(-1)
+        loss_mask = loss_mask[:, :-1]

         if self.loss_type == 'ipo':  # average_log_prob
-            all_logps = all_logps / loss_mask.sum(-1)
-            all_ref_logps = all_ref_logps / loss_mask.sum(-1)
+            policy_logps = policy_logps / loss_mask.sum(-1)
+            ref_logps = ref_logps / loss_mask.sum(-1)

-        policy_chosen_logps = all_logps[::2]
-        policy_rejected_logps = all_logps[1::2]
-        reference_chosen_logps = all_ref_logps[::2]
-        reference_rejected_logps = all_ref_logps[1::2]
+        policy_chosen_logps = policy_logps[::2]
+        policy_rejected_logps = policy_logps[1::2]
+        reference_chosen_logps = ref_logps[::2]
+        reference_rejected_logps = ref_logps[1::2]
         return (policy_chosen_logps, policy_rejected_logps,
                 reference_chosen_logps, reference_rejected_logps)

-    def get_var_len_atten_logps(self, all_logits, all_ref_logits, labels,
+    def get_var_len_atten_logps(self, policy_logps, ref_logps, loss_mask,
                                 cu_seqlens, attention_mask):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         # unpack sequence
-        unpacked_logits = torch.split(all_logits, seqlens, dim=1)
-        unpacked_ref_logits = torch.split(all_ref_logits, seqlens, dim=1)
-        unpacked_labels = torch.split(labels, seqlens, dim=1)
+        unpacked_policy_logps = torch.split(policy_logps, seqlens, dim=1)
+        unpacked_ref_logps = torch.split(ref_logps, seqlens, dim=1)
+        unpacked_loss_mask = torch.split(loss_mask, seqlens, dim=1)
         if attention_mask is not None:
             # It indicates that we pad the original sequence, labels,
             # position_ids and cumulative_len for sequence parallel if the
             # attention_mask is not None.
             # We then need to remove the padded segments.
             assert False in attention_mask
-            unpacked_logits = unpacked_logits[:-1]
-            unpacked_ref_logits = unpacked_ref_logits[:-1]
-            unpacked_labels = unpacked_labels[:-1]
-        assert len(unpacked_logits) % 2 == 0
+            unpacked_policy_logps = unpacked_policy_logps[:-1]
+            unpacked_ref_logps = unpacked_ref_logps[:-1]
+            unpacked_loss_mask = unpacked_loss_mask[:-1]
+        assert len(unpacked_policy_logps) % 2 == 0

-        def compute_logps(_logits, _labels):
-            _labels = _labels[:, 1:].clone()
-            _logits = _logits[:, :-1, :]
-            _labels[_labels == -100] = 0
-            loss_mask = _labels != 0
-            logps = self._gather_masked_logits(_logits, _labels, loss_mask)
-            logps = logps.sum(-1)
+        def compute_logps(_logps, _mask):
+            _logps = _logps[:, :-1].sum(-1)
+            _mask = _mask[:, :-1]
             if self.loss_type == 'ipo':
-                logps /= loss_mask.sum(-1)
-            return logps
+                _logps /= _mask.sum(-1)
+            return _logps

         (policy_chosen_logps, policy_rejected_logps, reference_chosen_logps,
          reference_rejected_logps) = [], [], [], []
-        for i in range(len(unpacked_logits) // 2):
-            chosen = unpacked_logits[2 * i]
-            rejected = unpacked_logits[2 * i + 1]
-            chosen_ref = unpacked_ref_logits[2 * i]
-            rejected_ref = unpacked_ref_logits[2 * i + 1]
-            chosen_label = unpacked_labels[2 * i]
-            rejected_label = unpacked_labels[2 * i + 1]
-            policy_chosen_logps.append(compute_logps(chosen, chosen_label))
+        for i in range(len(unpacked_policy_logps) // 2):
+            chosen = unpacked_policy_logps[2 * i]
+            rejected = unpacked_policy_logps[2 * i + 1]
+            chosen_ref = unpacked_ref_logps[2 * i]
+            rejected_ref = unpacked_ref_logps[2 * i + 1]
+            chosen_mask = unpacked_loss_mask[2 * i]
+            rejected_mask = unpacked_loss_mask[2 * i + 1]
+            policy_chosen_logps.append(compute_logps(chosen, chosen_mask))
             policy_rejected_logps.append(
-                compute_logps(rejected, rejected_label))
+                compute_logps(rejected, rejected_mask))
             reference_chosen_logps.append(
-                compute_logps(chosen_ref, chosen_label))
+                compute_logps(chosen_ref, chosen_mask))
             reference_rejected_logps.append(
-                compute_logps(rejected_ref, rejected_label))
+                compute_logps(rejected_ref, rejected_mask))

         return (torch.stack(policy_chosen_logps),
                 torch.stack(policy_rejected_logps),
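get_logps and get_var_len_atten_logps now receive log-probabilities that were already picked out per label token, so they only drop the trailing dummy position and sum (or average, for IPO) over the sequence. A minimal sketch of what such a masked gather amounts to, assuming _gather_masked_logits selects the log-probability of each target label and zeroes masked positions (the real helper is defined elsewhere in this file):

import torch

def gather_masked_logits_sketch(logits, labels, mask):
    # assumed behaviour: log-softmax over the vocab, pick the target label's
    # log-prob at every position, zero out positions excluded from the loss
    logps = torch.log_softmax(logits, dim=-1)                    # (bs, seqlen, vocab)
    picked = logps.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (bs, seqlen)
    return picked * mask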
@@ -142,7 +131,7 @@ def compute_logps(_logits, _labels):
     @staticmethod
     def _split_for_sequence_parallel(data):
         # attention mask should not be split
-        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
+        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids', 'labels')
         sp_group = get_sequence_parallel_group()
         for key in ARGS_NEED_TO_SPLIT:
             val = data.get(key, None)
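'labels' joins the tensors that are split along the sequence dimension because the masked log-probabilities are now computed per rank against that rank's local slice of the labels. A toy illustration of such a split (assumed helper, not the xtuner implementation):

import torch

def split_for_sp_sketch(tensor, sp_rank, sp_world_size, dim=1):
    # illustrative only: each sequence-parallel rank keeps one contiguous
    # chunk along the sequence dimension, just like input_ids and position_ids
    return tensor.chunk(sp_world_size, dim=dim)[sp_rank]

labels = torch.arange(8).reshape(1, 8)
print(split_for_sp_sketch(labels, sp_rank=1, sp_world_size=2))  # tensor([[4, 5, 6, 7]])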
@@ -154,8 +143,14 @@ def _split_for_sequence_parallel(data):

     def compute_loss(self, data, data_samples=None):
         # modified from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py  # noqa
-
-        labels = data.pop('labels')
+        # shift labels first and add a dummy label at the end, to support sequence parallel  # noqa
+        data['labels'] = torch.cat(
+            (data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
+            dim=1)
+        tmp_label = data['labels'].clone()
+        tmp_label[tmp_label == 0] = -100
+        all_loss_mask = data[
+            'labels'] != -100  # loss mask of all tokens in all sp ranks  # noqa

         if get_sequence_parallel_world_size() > 1:
             data = self._split_for_sequence_parallel(data)
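Shifting the labels once, up front, and appending a dummy label replaces the usual shift inside the loss. This matters under sequence parallelism: each rank only holds a contiguous slice of the sequence, so a per-rank shift would misalign logits and targets at slice boundaries, while pre-shifted labels keep logits[t] paired with labels[t] everywhere. A small sketch of the trick (toy values, not code from this PR):

import torch

def shift_labels_with_dummy(labels, dummy=0):
    # labels are shifted left by one so that position t holds the target for
    # the logits at position t, and a dummy label pads the last position
    return torch.cat((labels[:, 1:], torch.full_like(labels[:, :1], dummy)),
                     dim=1)

labels = torch.tensor([[-100, 5, 6, 7]])
print(shift_labels_with_dummy(labels))  # tensor([[  5,   6,   7,   0]])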
@@ -168,14 +163,22 @@
         else:
             all_ref_logits = self.ref_llm(**data).logits

+        labels = data['labels']
+        labels[labels == -100] = 0
+        loss_mask = labels != 0  # loss mask in a single sp rank
+        policy_logps = self._gather_masked_logits(all_logits, labels,
+                                                  loss_mask)
+        ref_logps = self._gather_masked_logits(all_ref_logits, labels,
+                                               loss_mask)
+
         if get_sequence_parallel_world_size() > 1:
-            all_logits = gather_forward_split_backward(
-                all_logits,
+            policy_logps = gather_forward_split_backward(
+                policy_logps,
                 dim=1,
                 sp_group=get_sequence_parallel_group(),
                 grad_scale='up')
-            all_ref_logits = gather_forward_split_backward(
-                all_ref_logits,
+            ref_logps = gather_forward_split_backward(
+                ref_logps,
                 dim=1,
                 sp_group=get_sequence_parallel_group(),
                 grad_scale='up')
@@ -184,15 +187,15 @@
             (policy_chosen_logps, policy_rejected_logps,
              reference_chosen_logps,
              reference_rejected_logps) = self.get_logps(
-                 all_logits, all_ref_logits, labels)
+                 policy_logps, ref_logps, all_loss_mask)
         else:
             message_hub = MessageHub.get_instance('varlen_attn_args')
             rank = dist.get_rank()
             cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
             (policy_chosen_logps, policy_rejected_logps,
              reference_chosen_logps,
              reference_rejected_logps) = self.get_var_len_atten_logps(
-                 all_logits, all_ref_logits, labels, cu_seqlens,
+                 policy_logps, ref_logps, all_loss_mask, cu_seqlens,
                  data['attention_mask'])

         pi_logratios = policy_chosen_logps - policy_rejected_logps
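The per-rank computation is exact, not an approximation: picking the label log-probability is a position-wise operation, so doing it on each sequence shard and concatenating gives the same result as doing it on the full sequence. That is why only the small (bs, seqlen) tensor needs to be all-gathered across sp ranks. A toy check under assumed shapes (no distributed setup):

import torch

torch.manual_seed(0)
bs, seqlen, vocab, sp = 2, 8, 16, 2
logits = torch.randn(bs, seqlen, vocab)
labels = torch.randint(0, vocab, (bs, seqlen))
mask = torch.rand(bs, seqlen) > 0.3

def masked_label_logps(lg, lb, m):
    # per-token log-prob of the target label, zeroed where the mask is False
    return lg.log_softmax(-1).gather(-1, lb.unsqueeze(-1)).squeeze(-1) * m

full = masked_label_logps(logits, labels, mask)
sharded = torch.cat([
    masked_label_logps(lg, lb, m) for lg, lb, m in zip(
        logits.chunk(sp, dim=1), labels.chunk(sp, dim=1), mask.chunk(sp, dim=1))
], dim=1)
assert torch.allclose(full, sharded)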
135 changes: 69 additions & 66 deletions xtuner/model/orpo.py
@@ -34,17 +34,12 @@ def _gather_masked_logits(self, logits, labels, mask):

     def get_logps(
             self,
-            all_logits,  # bs, seqlen, vocab_size
-            average_log_prob,  # bs, seqlen, vocab_size
-            labels,  # bs, seqlen
+            all_logps,  # bs, seqlen
+            average_log_prob,
+            loss_mask,  # bs, seqlen
     ):
-        labels = labels[:, 1:].clone()
-        all_logits = all_logits[:, :-1, :]
-
-        labels[labels == -100] = 0
-        loss_mask = labels != 0
-        all_logps = self._gather_masked_logits(all_logits, labels,
-                                               loss_mask).sum(-1)
+        all_logps = all_logps[:, :-1].sum(-1)
+        loss_mask = loss_mask[:, :-1]

         if average_log_prob:  # average_log_prob
             all_logps = all_logps / loss_mask.sum(-1)
@@ -53,47 +48,44 @@ def get_logps(
         rejected_logps = all_logps[1::2]
         return chosen_logps, rejected_logps

-    def get_var_len_atten_logps(self, all_logits, average_log_prob, labels,
+    def get_var_len_atten_logps(self, all_logps, average_log_prob, loss_mask,
                                 cu_seqlens, attention_mask):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         # unpack sequence
-        unpacked_logits = torch.split(all_logits, seqlens, dim=1)
-        unpacked_labels = torch.split(labels, seqlens, dim=1)
+        unpacked_logps = torch.split(all_logps, seqlens, dim=1)
+        unpacked_loss_mask = torch.split(loss_mask, seqlens, dim=1)
         if attention_mask is not None:
             # It indicates that we pad the original sequence, labels,
             # position_ids and cumulative_len for sequence parallel if the
             # attention_mask is not None.
             # We then need to remove the padded segments.
             assert False in attention_mask
-            unpacked_logits = unpacked_logits[:-1]
-            unpacked_labels = unpacked_labels[:-1]
-        assert len(unpacked_logits) % 2 == 0
+            unpacked_logps = unpacked_logps[:-1]
+            unpacked_loss_mask = unpacked_loss_mask[:-1]
+        assert len(unpacked_logps) % 2 == 0

-        def compute_logps(_logits, _labels):
-            _labels = _labels[:, 1:].clone()
-            _logits = _logits[:, :-1, :]
-            _labels[_labels == -100] = 0
-            loss_mask = _labels != 0
-            logps = self._gather_masked_logits(_logits, _labels, loss_mask)
-            logps = logps.sum(-1)
+        def compute_logps(_logps, _mask):
+            _logps = _logps[:, :-1].sum(-1)
+            _mask = _mask[:, :-1]
             if average_log_prob:
-                logps /= loss_mask.sum(-1)
-            return logps
+                _logps /= _mask.sum(-1)
+            return _logps

         chosen_logps, rejected_logps = [], []
-        for i in range(len(unpacked_logits) // 2):
-            chosen = unpacked_logits[2 * i]
-            rejected = unpacked_logits[2 * i + 1]
-            chosen_label = unpacked_labels[2 * i]
-            rejected_label = unpacked_labels[2 * i + 1]
-            chosen_logps.append(compute_logps(chosen, chosen_label))
-            rejected_logps.append(compute_logps(rejected, rejected_label))
+        for i in range(len(unpacked_logps) // 2):
+            chosen = unpacked_logps[2 * i]
+            rejected = unpacked_logps[2 * i + 1]
+            chosen_mask = unpacked_loss_mask[2 * i]
+            rejected_mask = unpacked_loss_mask[2 * i + 1]
+            chosen_logps.append(compute_logps(chosen, chosen_mask))
+            rejected_logps.append(compute_logps(rejected, rejected_mask))

         return (torch.stack(chosen_logps), torch.stack(rejected_logps))

     def cross_entropy_loss(self, logits, labels):
         logits = logits[..., :-1, :].contiguous()
-        labels = labels[..., 1:].contiguous()
+        # labels are already shifted, now we need to remove the last dummy label  # noqa
+        labels = labels[..., :-1].contiguous()
         # Flatten the tokens
         loss_fct = nn.CrossEntropyLoss()
         logits = logits.view(-1, logits.shape[-1])
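Because the labels fed to cross_entropy_loss are now pre-shifted with a trailing dummy, dropping the last position recovers exactly the same targets as the old shift inside the loss. A toy check under assumed shapes:

import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(1, 6, 10)
labels = torch.randint(0, 10, (1, 6))

def ce(lg, lb):
    return nn.CrossEntropyLoss()(lg.reshape(-1, lg.shape[-1]), lb.reshape(-1))

# old: shift inside the loss, logits[:-1] vs labels[1:]
old = ce(logits[:, :-1, :], labels[:, 1:])
# new: labels pre-shifted left with a dummy appended, then both truncated
shifted = torch.cat((labels[:, 1:], torch.zeros_like(labels[:, :1])), dim=1)
new = ce(logits[:, :-1, :], shifted[:, :-1])
assert torch.allclose(old, new)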
@@ -126,7 +118,8 @@ def odds_ratio_loss(
     @staticmethod
     def _split_for_sequence_parallel(data):
         # attention mask should not be split
-        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
+        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids', 'labels',
+                              'chosen_rejected_tag')
         sp_group = get_sequence_parallel_group()
         for key in ARGS_NEED_TO_SPLIT:
             val = data.get(key, None)
@@ -137,53 +130,63 @@ def _split_for_sequence_parallel(data):
         return data

     def compute_loss(self, data, data_samples=None):
-        labels_ori = data.pop('labels')
+        # shift labels first and add a dummy label at the end, to support sequence parallel  # noqa
+        data['labels'] = torch.cat(
+            (data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
+            dim=1)
+        tmp_label = data['labels'].clone()
+        tmp_label[tmp_label == 0] = -100
+        # loss mask of all tokens in all sp ranks
+        all_loss_mask = data['labels'] != -100

+        if self.use_varlen_attn:
+            # create a chosen rejected tag for varlen_attn ce loss
+            message_hub = MessageHub.get_instance('varlen_attn_args')
+            rank = dist.get_rank()
+            cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+
+            chosen_rejected_tag = torch.ones_like(data['labels'])
+            unpacked_tag = list(
+                torch.split(chosen_rejected_tag, seqlens, dim=1))
+            # import pdb; pdb.set_trace()
+            for i in range(len(unpacked_tag) // 2):
+                # import pdb; pdb.set_trace()
+                unpacked_tag[2 * i + 1] *= 0
+            chosen_rejected_tag = torch.cat(unpacked_tag, dim=1)
+            data['chosen_rejected_tag'] = chosen_rejected_tag
+
         if get_sequence_parallel_world_size() > 1:
             data = self._split_for_sequence_parallel(data)

+        chosen_rejected_tag = data.pop('chosen_rejected_tag', None)
         all_logits = self.llm(**data).logits

+        labels = data['labels'].clone()
+        labels[labels == -100] = 0
+        loss_mask = labels != 0  # loss mask in a single sp rank
+        all_logps = self._gather_masked_logits(all_logits, labels, loss_mask)
         if get_sequence_parallel_world_size() > 1:
-            all_logits = gather_forward_split_backward(
-                all_logits,
+            all_logps = gather_forward_split_backward(
+                all_logps,
                 dim=1,
                 sp_group=get_sequence_parallel_group(),
                 grad_scale='up')

         if not self.use_varlen_attn:
             chosen_nll_loss = self.cross_entropy_loss(all_logits[::2],
-                                                      labels_ori.clone()[::2])
+                                                      data['labels'][::2])
             chosen_logps, rejected_logps = self.get_logps(
-                all_logits, True, labels_ori)
+                all_logps, True, all_loss_mask)
         else:
             message_hub = MessageHub.get_instance('varlen_attn_args')
             rank = dist.get_rank()
             cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-
-            attention_mask = data['attention_mask']
-            if attention_mask is not None:
-                # It indicate that we pad the original sequence, labels,
-                # position_ids and cumulative_len for sequence parallel if the
-                # attention_mask is not None.
-                # We then need to remove the padded segments.
-                logits = torch.split(all_logits, seqlens, dim=1)[:-1]
-                assert len(logits) % 2 == 0
-                chosen_logits = logits[::2]
-                labels = torch.split(labels_ori.clone(), seqlens, dim=1)[:-1]
-                assert len(labels) % 2 == 0
-                chosen_labels = labels[::2]
-            else:
-                chosen_logits = torch.split(all_logits, seqlens, dim=1)[::2]
-                chosen_labels = torch.split(
-                    labels_ori.clone(), seqlens, dim=1)[::2]
-
-            chosen_logits = torch.cat(chosen_logits, dim=1)
-            chosen_labels = torch.cat(chosen_labels, dim=1)
+            chosen_idxs = chosen_rejected_tag == 1
+            chosen_logits = all_logits[chosen_idxs]
+            chosen_labels = data['labels'][chosen_idxs]
             chosen_nll_loss = self.cross_entropy_loss(chosen_logits,
                                                       chosen_labels)

             chosen_logps, rejected_logps = self.get_var_len_atten_logps(
-                all_logits, True, labels_ori, cu_seqlens, attention_mask)
+                all_logps, True, all_loss_mask, cu_seqlens,
+                data['attention_mask'])
         (losses, chosen_rewards, rejected_rewards, log_odds_ratio,
          log_odds_chosen) = self.odds_ratio_loss(chosen_logps, rejected_logps)
         losses = losses.mean()
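The chosen_rejected_tag built above marks every token of a packed sequence as belonging to a chosen (1) or rejected (0) response; after the sequence-parallel split, each rank can pick its local chosen tokens for the NLL term with a boolean index instead of re-splitting logits by seqlens. An illustrative reconstruction with toy sizes (assumed values, not from this PR):

import torch

cu_seqlens = torch.tensor([0, 3, 5, 9, 12])           # two chosen/rejected pairs
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
labels = torch.zeros(1, int(cu_seqlens[-1]), dtype=torch.long)

tag = torch.ones_like(labels)
segments = list(torch.split(tag, seqlens, dim=1))
for i in range(len(segments) // 2):
    segments[2 * i + 1] *= 0                           # zero out the rejected segment
tag = torch.cat(segments, dim=1)
print(tag)  # tensor([[1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0]])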