diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index d5e17a83e9..d38b3b2053 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -23,6 +23,7 @@ def __init__(
         loss_agg_mode: Optional[str] = "token-mean",
         enable_sequence_masking: bool = False,  # introduced in DeepseekV3.2
         delta_sequence_masking: float = 0.1,
+        fallback_to_policy_gradient: bool = False,
     ) -> None:
         super().__init__(backend=backend)
         if clip_range_low is None:
@@ -40,6 +41,7 @@ def __init__(
         self.loss_agg_mode = loss_agg_mode
         self.enable_sequence_masking = enable_sequence_masking
         self.delta_sequence_masking = delta_sequence_masking
+        self.fallback_to_policy_gradient = fallback_to_policy_gradient

     def __call__(  # type: ignore
         self,
@@ -50,6 +52,9 @@ def __call__(  # type: ignore
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         negative_approx_kl = logprob - old_logprob
+        if self.fallback_to_policy_gradient:
+            # ignore vllm logprob difference and use pure policy gradient loss
+            negative_approx_kl = logprob - logprob.detach()
         # Clamp negative_approx_kl for stability
         negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
         ratio = torch.exp(negative_approx_kl)
@@ -119,4 +124,5 @@ def default_args(cls) -> Dict:
             "loss_agg_mode": "token-mean",
             "enable_sequence_masking": False,
             "delta_sequence_masking": 0.1,
+            "fallback_to_policy_gradient": False,
         }
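
For reviewers, a quick illustration of what the new flag does: with `fallback_to_policy_gradient=True`, the clipped-ratio term degenerates into a vanilla policy-gradient term, because `logprob - logprob.detach()` is exactly zero in the forward pass (ratio of 1, so clipping never triggers) while gradients still flow through `logprob`. A minimal sketch, assuming only PyTorch; the tensor values are invented for illustration and are not part of this diff:

```python
# Minimal sketch (illustration only): why `logprob - logprob.detach()`
# falls back to a plain policy-gradient loss. Tensor values are made up.
import torch

logprob = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
advantages = torch.tensor([0.5, -0.3, 1.1])

# Fallback path: the log-ratio is exactly zero in the forward pass, so the
# ratio is 1 and PPO clipping never triggers, but gradients still flow
# through `logprob`.
ratio = torch.exp(logprob - logprob.detach())
assert torch.allclose(ratio, torch.ones_like(ratio))

pg_loss = -(ratio * advantages).mean()
pg_loss.backward()

# The gradient matches that of the vanilla policy-gradient objective
# -(advantages * logprob).mean(): here -advantages / N per element.
assert torch.allclose(logprob.grad, -advantages / logprob.numel())
```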