From 68aaa4904b8dfb6cc791fdcee613edc681a8a198 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 31 Mar 2024 19:27:08 +0800 Subject: [PATCH] use log1p in orpo loss https://github.com/huggingface/trl/pull/1491 --- src/llmtuner/train/orpo/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmtuner/train/orpo/trainer.py b/src/llmtuner/train/orpo/trainer.py index 291351e430..af34b55ed5 100644 --- a/src/llmtuner/train/orpo/trainer.py +++ b/src/llmtuner/train/orpo/trainer.py @@ -84,7 +84,7 @@ def odds_ratio_loss( # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) log_odds = (chosen_logps - rejected_logps) - ( - torch.log(1 - torch.exp(chosen_logps)) - torch.log(1 - torch.exp(rejected_logps)) + torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps)) ) ratio = F.logsigmoid(log_odds) losses = self.beta * ratio