Skip to content

Commit

Permalink
Update dpo.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
xu-song authored Jun 18, 2024
1 parent ba95068 commit 5a16fa2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions xtuner/model/dpo.py
Original file line number Diff line number Diff line change
Expand Up
@@ -54,8 +54,8 @@ def __init__(self,
ref_llm = self._build_llm_from_cfg(ref_llm, kwargs.get("use_varlen_attn"), kwargs.get("max_position_embeddings"))
self.ref_llm = disable_grad(ref_llm)
else:
- if not self.use_lora:
- self.ref_llm = create_reference_model(self.llm)
+ self.ref_llm = None if self.use_lora else create_reference_model(self.llm)


def _gather_masked_logits(self, logits, labels, mask):
logits = torch.gather(
Expand Down

0 comments on commit 5a16fa2

Please sign in to comment.