diff --git a/xtuner/model/dpo.py b/xtuner/model/dpo.py
index 25039dca2..e1d069489 100644
--- a/xtuner/model/dpo.py
+++ b/xtuner/model/dpo.py
@@ -54,8 +54,8 @@ def __init__(self,
             ref_llm = self._build_llm_from_cfg(ref_llm, kwargs.get("use_varlen_attn"),
                                                kwargs.get("max_position_embeddings"))
             self.ref_llm = disable_grad(ref_llm)
         else:
-            if not self.use_lora:
-                self.ref_llm = create_reference_model(self.llm)
+            self.ref_llm = None if self.use_lora else create_reference_model(self.llm)
+
     def _gather_masked_logits(self, logits, labels, mask):
         logits = torch.gather(