diff --git a/xtuner/model/dpo.py b/xtuner/model/dpo.py
index 1ba44ca9d..9a7b97a19 100644
--- a/xtuner/model/dpo.py
+++ b/xtuner/model/dpo.py
@@ -147,6 +147,10 @@ def compute_loss(self, data, data_samples=None):
         data['labels'] = torch.cat(
             (data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
             dim=1)
+        tmp_label = data['labels'].clone()
+        tmp_label[tmp_label == 0] = -100
+        all_loss_mask = data[
+            'labels'] != -100  # loss mask of all tokens in all sp ranks  # noqa
 
         if get_sequence_parallel_world_size() > 1:
             data = self._split_for_sequence_parallel(data)
@@ -161,7 +165,7 @@ def compute_loss(self, data, data_samples=None):
 
         labels = data['labels']
         labels[labels == -100] = 0
-        loss_mask = labels != 0
+        loss_mask = labels != 0  # loss mask in a single sp rank
         policy_logps = self._gather_masked_logits(all_logits, labels,
                                                   loss_mask)
         ref_logps = self._gather_masked_logits(all_ref_logits, labels,
@@ -183,7 +187,7 @@ def compute_loss(self, data, data_samples=None):
             (policy_chosen_logps, policy_rejected_logps,
              reference_chosen_logps,
              reference_rejected_logps) = self.get_logps(
-                 policy_logps, ref_logps, loss_mask)
+                 policy_logps, ref_logps, all_loss_mask)
         else:
             message_hub = MessageHub.get_instance('varlen_attn_args')
             rank = dist.get_rank()
@@ -191,7 +195,7 @@
             (policy_chosen_logps, policy_rejected_logps,
              reference_chosen_logps,
              reference_rejected_logps) = self.get_var_len_atten_logps(
-                 policy_logps, ref_logps, loss_mask, cu_seqlens,
+                 policy_logps, ref_logps, all_loss_mask, cu_seqlens,
                  data['attention_mask'])
 
         pi_logratios = policy_chosen_logps - policy_rejected_logps
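
What the swap from loss_mask to all_loss_mask changes, as the new comments state: all_loss_mask is taken from the full-length labels before _split_for_sequence_parallel, so it covers the tokens held by every sequence-parallel rank, while loss_mask is computed after the split and only covers the local shard. Below is a minimal, self-contained sketch of that distinction; it is not xtuner code, and names such as sp_world_size and the plain chunk() split are illustrative assumptions.

import torch

# Toy shapes: one sample of 8 tokens, split across 2 sequence-parallel ranks.
sp_world_size = 2
labels = torch.tensor([[0, 0, 5, 7, 9, 2, 0, 0]])  # 0 marks ignored tokens

# Mask computed before the split: covers all tokens in all sp ranks.
all_loss_mask = labels != 0                  # shape [1, 8]

# After the split, each rank only sees its own shard of the sequence,
# so a mask derived from the local labels only covers those tokens.
local_labels = labels.chunk(sp_world_size, dim=-1)[0]  # rank 0's shard
local_loss_mask = local_labels != 0          # shape [1, 4]

# Per-token log-probs that have been gathered back to full sequence length
# can only be masked correctly with the full-length mask.
gathered_logps = torch.randn(1, 8)
chosen_sum = (gathered_logps * all_loss_mask).sum(-1)
print(all_loss_mask.shape, local_loss_mask.shape, chosen_sum.shape)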