3 changes: 2 additions & 1 deletion README.md
@@ -130,6 +130,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
| AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |



@@ -142,7 +143,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
- [Step 2: prepare dataset and model](#step-2-prepare-dataset-and-model)
- [Step 3: configurations](#step-3-configurations)
- [Step 4: run the RFT process](#step-4-run-the-rft-process)
- [Contribution guide](#contribution-guide)
- [Contribution Guide](#contribution-guide)
- [Acknowledgements](#acknowledgements)
- [Citation](#citation)

1 change: 1 addition & 0 deletions README_zh.md
@@ -129,6 +129,7 @@ Trinity-RFT provides corresponding features for users with different backgrounds and goals:
| AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |



2 changes: 1 addition & 1 deletion benchmark/reports/gsm8k.md
@@ -188,7 +188,7 @@ class VerlGSM8kWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
self.reset(task)
super().__init__(
Binary file added docs/sphinx_doc/assets/opd_acc.png
Binary file added docs/sphinx_doc/assets/opd_kl.png
2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/main.md
@@ -86,6 +86,8 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
| AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |




11 changes: 6 additions & 5 deletions docs/sphinx_doc/source/tutorial/develop_workflow.md
@@ -93,11 +93,12 @@ class Workflow(ABC):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
self.task = task
self.model = model
self.auxiliary_models = auxiliary_models
self.auxiliary_model_wrappers = auxiliary_models
self.auxiliary_models = ... # OpenAI clients auto-derived from ModelWrapper

@abstractmethod
def run(self) -> List[Experience]:
@@ -110,7 +111,7 @@ During initialization, `Workflow` receives the following parameters:

- `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
- `model` ({class}`trinity.common.models.model.ModelWrapper`): The model being trained. It provides an OpenAI-like interface that accepts a list of conversation messages and returns the content generated by the LLM (including the reply text `response_text`, the full-sequence token ids `tokens`, the prompt token length `prompt_length`, and the list of output token logprobs `logprobs`).
- `auxiliary_models`(`List[openai.OpenAI]`):A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
- `auxiliary_models` (`List[ModelWrapper]`): A list of auxiliary model wrappers not involved in training. OpenAI-compatible clients are available via `self.auxiliary_models` (derived automatically based on the workflow's `is_async`); see the sketch below.
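
For example, a synchronous workflow can query the first auxiliary model as an LLM judge through the standard OpenAI client API. This is a minimal, hypothetical sketch; the class name, the prompt, and the `models.list()` call used to discover the served model name are illustrative, not an API guarantee:

```python
from typing import List, Optional

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import Task, Workflow


class JudgedWorkflow(Workflow):
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[ModelWrapper]] = None,
    ):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        # In a sync workflow, self.auxiliary_models holds openai.OpenAI clients.
        self.judge = self.auxiliary_models[0]

    def run(self) -> List[Experience]:
        # Ask the judge to score a candidate answer (illustrative prompt).
        served_model = self.judge.models.list().data[0].id
        completion = self.judge.chat.completions.create(
            model=served_model,
            messages=[{"role": "user", "content": "Rate this answer from 0 to 10: ..."}],
        )
        score_text = completion.choices[0].message.content
        print(score_text)
        return []
```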

```{tip}
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
@@ -440,10 +441,10 @@ class MyWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.judge_model = self.auxiliary_models[0] # Use the first auxiliary model as the judge
self.judge_model = self.auxiliary_models[0] # OpenAI client auto-derived from ModelWrapper

def run(self) -> List[Experience]:
response = self.do_something()
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_react.md
@@ -82,7 +82,7 @@ class AgentScopeReActWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
# initialize the agent
self.agent = AgentScopeReActAgent(
1 change: 1 addition & 0 deletions docs/sphinx_doc/source_zh/main.md
@@ -82,6 +82,7 @@ Trinity-RFT provides corresponding features for users with different backgrounds and goals:
| AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` |
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |



11 changes: 6 additions & 5 deletions docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
@@ -92,11 +92,12 @@ class Workflow(ABC):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,  # mainly used for LLM-as-a-judge scenarios; can be ignored here
auxiliary_models: Optional[List[ModelWrapper]] = None,  # mainly used for LLM-as-a-judge scenarios; can also serve as the distillation teacher
):
self.task = task
self.model = model
self.auxiliary_models = auxiliary_models
self.auxiliary_model_wrappers = auxiliary_models
self.auxiliary_models = ...  # OpenAI clients auto-derived from ModelWrapper

@abstractmethod
def run(self) -> List[Experience]:
@@ -109,7 +110,7 @@ class Workflow(ABC):

- `task` ({class}`trinity.common.workflows.Task`): A single task from the task dataset.
- `model` ({class}`trinity.common.models.model.ModelWrapper`): The model being trained. It provides an OpenAI-like interface that accepts a list of conversation messages and returns the content generated by the LLM (including the reply text `response_text`, the full-sequence token ids `tokens`, the prompt token length `prompt_length`, and the list of output token logprobs `logprobs`).
- `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs, mainly for LLM-as-a-judge scenarios.
- `auxiliary_models` (`List[ModelWrapper]`): A list of auxiliary model wrappers. OpenAI clients are available via `self.auxiliary_models` (derived automatically based on the workflow's `is_async`).

Below is an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use `format_args` for further customization.

Expand Down Expand Up @@ -437,10 +438,10 @@ class MyWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.judge_model = self.auxiliary_models[0]  # Use the first auxiliary model as the judge
self.judge_model = self.auxiliary_models[0]  # OpenAI client auto-derived from ModelWrapper

def run(self) -> List[Experience]:
response = self.do_something()
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source_zh/tutorial/example_react.md
@@ -88,7 +88,7 @@ class AgentScopeReActWorkflow(Workflow):
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
# initialize the agent
self.agent = AgentScopeReActAgent(
4 changes: 1 addition & 3 deletions examples/learn_to_ask/workflow/workflow_learn2ask.py
@@ -7,8 +7,6 @@
import time
from typing import List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import SimpleWorkflow, Task
@@ -36,7 +34,7 @@ def __init__(
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
self.train_mode = task.workflow_args.get("train_mode", "Ra+Rs")
self.fusion_mode = task.workflow_args.get("fusion_mode", "default")
36 changes: 36 additions & 0 deletions examples/opd_gsm8k/README.md
@@ -0,0 +1,36 @@
# Example: On-Policy Distillation on the GSM8K dataset

This example demonstrates training with the On-Policy Distillation (OPD) algorithm on the GSM8K dataset.

On-Policy Distillation is a knowledge distillation method. In this example:
1. **Student model** (`Qwen/Qwen2.5-1.5B-Instruct`) generates trajectories with logprobs
2. **Teacher model** (`Qwen/Qwen2.5-Math-7B-Instruct`) computes logprobs on the same trajectories
3. The advantage is computed as: `advantages = kl_coef * (teacher_logprobs - student_logprobs)`
4. The student model is trained to minimize this KL divergence, effectively learning from the teacher (see the sketch below)
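
A minimal PyTorch sketch of step 3 above (tensor values are made up for illustration; the actual implementation lives in `trinity/algorithm/advantage_fn/on_policy_distill_advantage.py`):

```python
import torch

kl_coef = 1.0
# Per-token logprobs on the same student-generated response: [batch, response_len]
student_logprobs = torch.tensor([[-0.7, -0.3, -1.2]])
teacher_logprobs = torch.tensor([[-0.5, -0.2, -0.4]])
response_mask = torch.ones_like(student_logprobs)

# Positive advantage where the teacher assigns higher probability than the student.
advantages = kl_coef * (teacher_logprobs - student_logprobs) * response_mask

# Per-sequence reverse-KL estimate (the quantity the student learns to minimize).
reverse_kl = ((student_logprobs - teacher_logprobs) * response_mask).sum(dim=-1)
```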

## Key Configuration

- **Algorithm**: `on_policy_distill`
- **Workflow**: `on_policy_distill_workflow`
- **Student Model**: `Qwen/Qwen2.5-1.5B-Instruct`
- **Teacher Model**: `Qwen/Qwen2.5-Math-7B-Instruct` (configured as auxiliary model)

## Running the Example

Download the model checkpoints (student and teacher) and modify your config file as needed, then run:
```bash
trinity run examples/opd_gsm8k/opd_gsm8k.yaml
```

Then you are all set! The setup should be pretty simple 😄, and training should converge very quickly.



![Accuracy curve](../../docs/sphinx_doc/assets/opd_acc.png)
![KL curve](../../docs/sphinx_doc/assets/opd_kl.png)


## References

- https://arxiv.org/pdf/2306.13649
- https://thinkingmachines.ai/blog/on-policy-distillation/
74 changes: 74 additions & 0 deletions examples/opd_gsm8k/opd_gsm8k.yaml
@@ -0,0 +1,74 @@
project: "Trinity-RFT-gsm8k-opd"
name: "qwen2.5-1.5B-distill-from-math-7B-lr1e-5"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: on_policy_distill
repeat_times: 8
optimizer:
lr: 1e-5
advantage_fn_args:
kl_coef: 1.0
model:
# Student model
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
max_response_tokens: 1024
max_model_len: 2048
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
subset_name: main
split: train
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
# Use on_policy_distill_math_workflow for Qwen2.5-Math style format with accuracy reward
default_workflow_type: 'on_policy_distill_math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_opd_buffer
storage_type: queue
explorer:
eval_interval: 50
runner_per_model: 8
rollout_model:
# Student model for rollout
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
auxiliary_models:
# Teacher model for distillation
- model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-Math-7B-Instruct}
engine_num: 1
tensor_parallel_size: 2
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
max_model_len: 4096
max_prompt_tokens: 2048
max_response_tokens: 1024
synchronizer:
sync_method: 'nccl'
sync_interval: 1
sync_timeout: 1200
trainer:
save_interval: 100
grad_clip: 1.0
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
monitor:
monitor_type: wandb
14 changes: 8 additions & 6 deletions tests/explorer/workflow_test.py
@@ -48,9 +48,10 @@ def __init__(self, model, task: Task, auxiliary_models=None):
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]
self.repeat_times = task.rollout_args.n
if auxiliary_models is not None:
for model in auxiliary_models:
assert isinstance(model, openai.OpenAI)
# Check self.auxiliary_models (OpenAI clients derived from ModelWrapper)
if self.auxiliary_models is not None:
for m in self.auxiliary_models:
assert isinstance(m, openai.OpenAI)

def reset(self, task: Task):
self.obj = task.raw_task
@@ -92,9 +93,10 @@ def __init__(self, model, task: Task, auxiliary_models=None):
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]
self.repeat_times = task.rollout_args.n
if auxiliary_models is not None:
for model in auxiliary_models:
assert isinstance(model, openai.AsyncOpenAI)
# Check self.auxiliary_models (AsyncOpenAI clients derived from ModelWrapper)
if self.auxiliary_models is not None:
for m in self.auxiliary_models:
assert isinstance(m, openai.AsyncOpenAI)

def reset(self, task: Task):
self.obj = task.raw_task
1 change: 1 addition & 0 deletions trinity/algorithm/__init__.py
@@ -27,6 +27,7 @@
"sppo": "trinity.algorithm.algorithm.sPPOAlgorithm",
"rec": "trinity.algorithm.algorithm.RECAlgorithm",
"multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm",
"on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm",
},
)

1 change: 1 addition & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -17,6 +17,7 @@
"asymre": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREGroupAdvantage",
"asymre_verl": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREAdvantageFn",
"rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage",
"on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage",
},
)

68 changes: 68 additions & 0 deletions trinity/algorithm/advantage_fn/on_policy_distill_advantage.py
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""On-Policy Distillation advantage computation.

Reference: Tinker library's on-policy distillation.

    advantages = -(student_logprobs - teacher_logprobs)
               = teacher_logprobs - student_logprobs
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn


class OnPolicyDistillAdvantage(AdvantageFn):
    """Advantage function for on-policy distillation.

    Computes: advantages = kl_coef * (teacher_logprobs - student_logprobs)

    The teacher_logprobs should be stored in Experience.teacher_logprobs
    by the workflow during exploration.
    """

    def __init__(self, kl_coef: float = 1.0) -> None:
        self.kl_coef = kl_coef

    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
        """Compute advantages from teacher and student logprobs.

        Args:
            exps: DataProto containing:
                - old_log_probs: student's sampling logprobs [batch, seq]
                - teacher_log_probs: teacher's logprobs [batch, seq]
                - response_mask: mask for response tokens [batch, seq]

        Returns:
            exps: DataProto with advantages and returns added
            metrics: Dict with kl and advantage statistics
        """
        metrics = {}

        old_log_probs = exps.batch["old_log_probs"]  # student sampling logprobs
        teacher_log_probs = exps.batch["teacher_log_probs"]
        response_mask = exps.batch["response_mask"]

        # advantages = -(student - teacher) = teacher - student
        advantages = self.kl_coef * (teacher_log_probs - old_log_probs)

        # Apply mask
        advantages = advantages * response_mask

        exps.batch["advantages"] = advantages
        exps.batch["returns"] = advantages.clone()

        # Metrics
        kl_per_token = old_log_probs - teacher_log_probs
        kl_sum = (kl_per_token * response_mask).sum(dim=-1)
        metrics["kl/mean"] = kl_sum.mean().item()
        metrics["kl/std"] = kl_sum.std().item() if kl_sum.numel() > 1 else 0.0
        metrics["advantages/mean"] = advantages.sum(dim=-1).mean().item()

        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {"kl_coef": 1.0}