From d71878e9474385403dbfbb0224f3a3c6e11fb861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Wed, 22 Oct 2025 10:49:21 +0800 Subject: [PATCH 1/6] add tis fall back for ppo_policy_losss --- tests/algorithm/policy_loss_test.py | 26 ++++++++++ .../policy_loss_fn/ppo_policy_loss.py | 49 +++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index 134635c05a..dc6e896b1d 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -108,3 +108,29 @@ def test_mix_policy_loss(self): self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss)) self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss)) self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss)) + + def test_ppo_policy_loss_with_truncate_is(self): + """Test PPO policy loss with truncate large IS enabled.""" + policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + # Enable truncate large IS with custom bounds + policy_loss_fn_args["truncate_large_is"] = True + policy_loss_fn_args["truncate_is_range_low"] = 0.0 + policy_loss_fn_args["truncate_is_range_high"] = 2.0 + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) + + # Expected values with IS truncation enabled + # Need calculations for these values + # ppo_loss_truncated = torch.tensor(0.27213451266288757) + # pg_clipfrac_truncated = torch.tensor(0.36458331346511841) + # ppo_kl_truncated = torch.tensor(-0.21663446724414825) + + # self.assertTrue(torch.allclose(loss, ppo_loss_truncated)) + # self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_truncated)) + # self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_truncated)) + # self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated)) + # Check that IS truncation metric is present + self.assertIn("is_truncate_frac", metrics) + self.assertGreaterEqual(metrics["is_truncate_frac"], 0.0) + self.assertLessEqual(metrics["is_truncate_frac"], 1.0) diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index 9c9bbaf2a5..016cd9320c 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -20,7 +20,24 @@ def __init__( clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, loss_agg_mode: Optional[str] = "token-mean", + truncate_large_is: bool = False, + truncate_is_range_low: Optional[float] = None, + truncate_is_range_high: Optional[float] = None, ) -> None: + """ + Initialize PPO policy loss function. + + Args: + backend: Backend framework (default: "verl") + clip_range: Symmetric clipping range for PPO + clip_range_low: Lower bound for clipping (1.0 - clip_range_low) + clip_range_high: Upper bound for clipping (1.0 + clip_range_high) + loss_agg_mode: Loss aggregation mode (default: "token-mean") + truncate_large_is: Whether to truncate large importance sampling ratios + to handle computation errors between VLLM and transformer calculations + truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0) + truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0) + """ super().__init__(backend=backend) if clip_range_low is None: self.clip_range_low = clip_range @@ -33,6 +50,18 @@ def __init__( assert self.clip_range_low is not None, "clip_range_low must be specified." assert self.clip_range_high is not None, "clip_range_high must be specified." self.loss_agg_mode = loss_agg_mode + + # Truncate large IS configuration + self.truncate_large_is = truncate_large_is + if truncate_large_is: + self.truncate_is_range_low = truncate_is_range_low if truncate_is_range_low is not None else 0.0 + self.truncate_is_range_high = truncate_is_range_high if truncate_is_range_high is not None else 2.0 + assert self.truncate_is_range_low >= 0.0, "truncate_is_range_low must be non-negative." + assert self.truncate_is_range_high > self.truncate_is_range_low, \ + "truncate_is_range_high must be greater than truncate_is_range_low." + else: + self.truncate_is_range_low = None + self.truncate_is_range_high = None def __call__( # type: ignore self, @@ -46,6 +75,18 @@ def __call__( # type: ignore ratio = torch.exp(negative_approx_kl) ppo_kl = masked_mean(-negative_approx_kl, action_mask) + # Truncate large IS ratios if enabled + # This helps stabilize training when there are computation errors between + # VLLM and transformer calculations, especially for small probabilities + if self.truncate_large_is: + ratio_before_truncate = ratio.clone() + ratio = torch.clamp(ratio, self.truncate_is_range_low, self.truncate_is_range_high) + # Track how often truncation occurs + is_truncated = torch.ne(ratio, ratio_before_truncate).float() + is_truncate_frac = masked_mean(is_truncated, action_mask) + else: + is_truncate_frac = torch.tensor(0.0) + pg_losses = -advantages * ratio pg_losses2 = -advantages * torch.clamp( ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore @@ -60,6 +101,11 @@ def __call__( # type: ignore "ppo_kl": ppo_kl.detach().item(), "pg_loss": pg_loss.detach().item(), } + + # Add IS truncation metrics if enabled + if self.truncate_large_is: + metrics["is_truncate_frac"] = is_truncate_frac.detach().item() + return pg_loss, metrics @classmethod @@ -67,4 +113,7 @@ def default_args(cls) -> Dict: return { "clip_range": 0.2, "loss_agg_mode": "token-mean", + "truncate_large_is": False, + "truncate_is_range_low": 0.0, + "truncate_is_range_high": 2.0, } From 334ff8fcc22b38f3908b94ac8b402f60be91b856 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Wed, 22 Oct 2025 10:53:40 +0800 Subject: [PATCH 2/6] fix pre-commit --- tests/algorithm/policy_loss_test.py | 6 ++--- .../policy_loss_fn/ppo_policy_loss.py | 23 +++++++++++-------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index dc6e896b1d..ba0d055faf 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -119,13 +119,13 @@ def test_ppo_policy_loss_with_truncate_is(self): policy_loss_fn_args["truncate_is_range_high"] = 2.0 policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) - + # Expected values with IS truncation enabled # Need calculations for these values - # ppo_loss_truncated = torch.tensor(0.27213451266288757) + # ppo_loss_truncated = torch.tensor(0.27213451266288757) # pg_clipfrac_truncated = torch.tensor(0.36458331346511841) # ppo_kl_truncated = torch.tensor(-0.21663446724414825) - + # self.assertTrue(torch.allclose(loss, ppo_loss_truncated)) # self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_truncated)) # self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_truncated)) diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index 016cd9320c..d98e3b2f2c 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -26,14 +26,14 @@ def __init__( ) -> None: """ Initialize PPO policy loss function. - + Args: backend: Backend framework (default: "verl") clip_range: Symmetric clipping range for PPO clip_range_low: Lower bound for clipping (1.0 - clip_range_low) clip_range_high: Upper bound for clipping (1.0 + clip_range_high) loss_agg_mode: Loss aggregation mode (default: "token-mean") - truncate_large_is: Whether to truncate large importance sampling ratios + truncate_large_is: Whether to truncate large importance sampling ratios to handle computation errors between VLLM and transformer calculations truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0) truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0) @@ -50,15 +50,20 @@ def __init__( assert self.clip_range_low is not None, "clip_range_low must be specified." assert self.clip_range_high is not None, "clip_range_high must be specified." self.loss_agg_mode = loss_agg_mode - + # Truncate large IS configuration self.truncate_large_is = truncate_large_is if truncate_large_is: - self.truncate_is_range_low = truncate_is_range_low if truncate_is_range_low is not None else 0.0 - self.truncate_is_range_high = truncate_is_range_high if truncate_is_range_high is not None else 2.0 + self.truncate_is_range_low = ( + truncate_is_range_low if truncate_is_range_low is not None else 0.0 + ) + self.truncate_is_range_high = ( + truncate_is_range_high if truncate_is_range_high is not None else 2.0 + ) assert self.truncate_is_range_low >= 0.0, "truncate_is_range_low must be non-negative." - assert self.truncate_is_range_high > self.truncate_is_range_low, \ - "truncate_is_range_high must be greater than truncate_is_range_low." + assert ( + self.truncate_is_range_high > self.truncate_is_range_low + ), "truncate_is_range_high must be greater than truncate_is_range_low." else: self.truncate_is_range_low = None self.truncate_is_range_high = None @@ -101,11 +106,11 @@ def __call__( # type: ignore "ppo_kl": ppo_kl.detach().item(), "pg_loss": pg_loss.detach().item(), } - + # Add IS truncation metrics if enabled if self.truncate_large_is: metrics["is_truncate_frac"] = is_truncate_frac.detach().item() - + return pg_loss, metrics @classmethod From 2db4d99301fd8fb476ad3f8b788e5b6f00ec928a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Wed, 22 Oct 2025 12:41:47 +0800 Subject: [PATCH 3/6] add correct test values --- tests/algorithm/policy_loss_test.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index ba0d055faf..7f90f2d757 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -113,24 +113,27 @@ def test_ppo_policy_loss_with_truncate_is(self): """Test PPO policy loss with truncate large IS enabled.""" policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo") policy_loss_fn_args = policy_loss_fn_cls.default_args() - # Enable truncate large IS with custom bounds + # Enable truncate large IS with default bounds [0.0, 2.0] policy_loss_fn_args["truncate_large_is"] = True policy_loss_fn_args["truncate_is_range_low"] = 0.0 policy_loss_fn_args["truncate_is_range_high"] = 2.0 policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch) - # Expected values with IS truncation enabled - # Need calculations for these values - # ppo_loss_truncated = torch.tensor(0.27213451266288757) - # pg_clipfrac_truncated = torch.tensor(0.36458331346511841) - # ppo_kl_truncated = torch.tensor(-0.21663446724414825) + # Expected values with IS truncation enabled (range: [0.0, 2.0]) + ppo_loss_truncated = torch.tensor(0.2230827361345291) + pg_clipfrac_truncated = torch.tensor(0.3541666567325592) + ppo_kl_truncated = torch.tensor(-0.21663446724414825) + is_truncate_frac_expected = torch.tensor(0.2708333432674408) - # self.assertTrue(torch.allclose(loss, ppo_loss_truncated)) - # self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_truncated)) - # self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_truncated)) - # self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated)) - # Check that IS truncation metric is present + self.assertTrue(torch.allclose(loss, ppo_loss_truncated)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_truncated)) + self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_truncated)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated)) + # Check that IS truncation metric is present and has expected value self.assertIn("is_truncate_frac", metrics) + self.assertTrue( + torch.allclose(torch.tensor(metrics["is_truncate_frac"]), is_truncate_frac_expected) + ) self.assertGreaterEqual(metrics["is_truncate_frac"], 0.0) self.assertLessEqual(metrics["is_truncate_frac"], 1.0) From db021f5a962d1649542fc3d87996ebb900e6724d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 27 Oct 2025 10:42:21 +0800 Subject: [PATCH 4/6] resolve comments --- .../policy_loss_fn/ppo_policy_loss.py | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index d98e3b2f2c..47fdf715ac 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -21,8 +21,8 @@ def __init__( clip_range_high: Optional[float] = None, loss_agg_mode: Optional[str] = "token-mean", truncate_large_is: bool = False, - truncate_is_range_low: Optional[float] = None, - truncate_is_range_high: Optional[float] = None, + truncate_is_range_low: Optional[float] = 0.0, + truncate_is_range_high: Optional[float] = 2.0, ) -> None: """ Initialize PPO policy loss function. @@ -34,7 +34,7 @@ def __init__( clip_range_high: Upper bound for clipping (1.0 + clip_range_high) loss_agg_mode: Loss aggregation mode (default: "token-mean") truncate_large_is: Whether to truncate large importance sampling ratios - to handle computation errors between VLLM and transformer calculations + to handle calculation discrepancies between rollout and training engines truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0) truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0) """ @@ -54,19 +54,15 @@ def __init__( # Truncate large IS configuration self.truncate_large_is = truncate_large_is if truncate_large_is: - self.truncate_is_range_low = ( - truncate_is_range_low if truncate_is_range_low is not None else 0.0 - ) - self.truncate_is_range_high = ( - truncate_is_range_high if truncate_is_range_high is not None else 2.0 - ) + self.truncate_is_range_low = truncate_is_range_low + self.truncate_is_range_high = truncate_is_range_high + assert ( + self.truncate_is_range_low is not None and self.truncate_is_range_high is not None + ), "truncate_is_range_low and truncate_is_range_high must be specified." assert self.truncate_is_range_low >= 0.0, "truncate_is_range_low must be non-negative." assert ( self.truncate_is_range_high > self.truncate_is_range_low ), "truncate_is_range_high must be greater than truncate_is_range_low." - else: - self.truncate_is_range_low = None - self.truncate_is_range_high = None def __call__( # type: ignore self, @@ -81,14 +77,16 @@ def __call__( # type: ignore ppo_kl = masked_mean(-negative_approx_kl, action_mask) # Truncate large IS ratios if enabled - # This helps stabilize training when there are computation errors between - # VLLM and transformer calculations, especially for small probabilities + # This helps stabilize training when there are calculation discrepancies between + # rollout and training engines, especially for small probabilities if self.truncate_large_is: - ratio_before_truncate = ratio.clone() + # Track how often truncation occurs (before actually truncating) + # More efficient than cloning: directly check which values fall outside bounds + ratio_detached = ratio.detach() + is_truncate_frac = masked_mean( + (ratio_detached < self.truncate_is_range_low).float(), action_mask + ) + masked_mean((ratio_detached > self.truncate_is_range_high).float(), action_mask) ratio = torch.clamp(ratio, self.truncate_is_range_low, self.truncate_is_range_high) - # Track how often truncation occurs - is_truncated = torch.ne(ratio, ratio_before_truncate).float() - is_truncate_frac = masked_mean(is_truncated, action_mask) else: is_truncate_frac = torch.tensor(0.0) From a0e805dab295fdb8bb088f8f3bd249846ee326f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 27 Oct 2025 10:49:48 +0800 Subject: [PATCH 5/6] fix style --- trinity/algorithm/policy_loss_fn/ppo_policy_loss.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index 47fdf715ac..3f539ff75f 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -57,8 +57,11 @@ def __init__( self.truncate_is_range_low = truncate_is_range_low self.truncate_is_range_high = truncate_is_range_high assert ( - self.truncate_is_range_low is not None and self.truncate_is_range_high is not None - ), "truncate_is_range_low and truncate_is_range_high must be specified." + self.truncate_is_range_low is not None + ), "truncate_is_range_low must be specified." + assert ( + self.truncate_is_range_high is not None + ), "truncate_is_range_high must be specified." assert self.truncate_is_range_low >= 0.0, "truncate_is_range_low must be non-negative." assert ( self.truncate_is_range_high > self.truncate_is_range_low From 781e06b2a60bced124075d2d4ce2742091a1bade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E6=98=8A?= Date: Mon, 27 Oct 2025 11:22:56 +0800 Subject: [PATCH 6/6] remove unneccsary lines --- trinity/algorithm/policy_loss_fn/ppo_policy_loss.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index 3f539ff75f..b8cad22ce1 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -90,8 +90,6 @@ def __call__( # type: ignore (ratio_detached < self.truncate_is_range_low).float(), action_mask ) + masked_mean((ratio_detached > self.truncate_is_range_high).float(), action_mask) ratio = torch.clamp(ratio, self.truncate_is_range_low, self.truncate_is_range_high) - else: - is_truncate_frac = torch.tensor(0.0) pg_losses = -advantages * ratio pg_losses2 = -advantages * torch.clamp(