diff --git a/README.md b/README.md index 2b8449df1e..58526defa1 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor | AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` | | CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` | | SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` | +| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` | @@ -142,7 +143,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor - [Step 2: prepare dataset and model](#step-2-prepare-dataset-and-model) - [Step 3: configurations](#step-3-configurations) - [Step 4: run the RFT process](#step-4-run-the-rft-process) -- [Contribution guide](#contribution-guide) +- [Contribution Guide](#contribution-guide) - [Acknowledgements](#acknowledgements) - [Citation](#citation) diff --git a/README_zh.md b/README_zh.md index 63839442e1..fd90f4af51 100644 --- a/README_zh.md +++ b/README_zh.md @@ -129,6 +129,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: | AsymRE [[论文](https://arxiv.org/pdf/2506.20520)] | [[GSM8K 例子](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[代码](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` | | CISPO [[论文](https://arxiv.org/pdf/2506.13585)] | - | [[代码](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` | | SAPO [[论文](https://arxiv.org/pdf/2511.20347)] | - | [[代码](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` | +| On-Policy Distillation [[博客](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[论文](https://arxiv.org/pdf/2306.13649)] | [[GSM8K 示例](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[代码](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` | diff --git a/benchmark/reports/gsm8k.md b/benchmark/reports/gsm8k.md index 0c43309255..2234e22661 100644 --- a/benchmark/reports/gsm8k.md +++ b/benchmark/reports/gsm8k.md @@ -188,7 +188,7 @@ class VerlGSM8kWorkflow(Workflow): *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): self.reset(task) super().__init__( diff --git a/docs/sphinx_doc/assets/opd_acc.png b/docs/sphinx_doc/assets/opd_acc.png new file mode 100644 index 0000000000..1815ef47b5 Binary files /dev/null and 
b/docs/sphinx_doc/assets/opd_acc.png differ diff --git a/docs/sphinx_doc/assets/opd_kl.png b/docs/sphinx_doc/assets/opd_kl.png new file mode 100644 index 0000000000..3f8bcb271c Binary files /dev/null and b/docs/sphinx_doc/assets/opd_kl.png differ diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md index 919eca2908..97a0f276d2 100644 --- a/docs/sphinx_doc/source/main.md +++ b/docs/sphinx_doc/source/main.md @@ -86,6 +86,8 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor | AsymRE [[Paper](https://arxiv.org/pdf/2506.20520)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` | | CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` | | SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` | +| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[Code](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` | + diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md index f8b4551edd..9d7dc20411 100644 --- a/docs/sphinx_doc/source/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md @@ -93,11 +93,12 @@ class Workflow(ABC): *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): self.task = task self.model = model - self.auxiliary_models = auxiliary_models + self.auxiliary_model_wrappers = auxiliary_models + self.auxiliary_models = ... # OpenAI clients auto-derived from ModelWrapper @abstractmethod def run(self) -> List[Experience]: @@ -110,7 +111,7 @@ During initialization, `Workflow` receives the following parameters: - `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset. - `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`). -- `auxiliary_models`(`List[openai.OpenAI]`):A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs. +- `auxiliary_models`(`List[ModelWrapper]`): A list of auxiliary model wrappers. You can access OpenAI clients via `self.auxiliary_models` (auto-derived based on workflow's `is_async`). ```{tip} You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow. 
@@ -440,10 +441,10 @@ class MyWorkflow(Workflow): *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) - self.judge_model = self.auxiliary_models[0] # Use the first auxiliary model as the judge + self.judge_model = self.auxiliary_models[0] # OpenAI client auto-derived from ModelWrapper def run(self) -> List[Experience]: response = self.do_something() diff --git a/docs/sphinx_doc/source/tutorial/example_react.md b/docs/sphinx_doc/source/tutorial/example_react.md index 4264696ab9..eeaad84a04 100644 --- a/docs/sphinx_doc/source/tutorial/example_react.md +++ b/docs/sphinx_doc/source/tutorial/example_react.md @@ -82,7 +82,7 @@ class AgentScopeReActWorkflow(Workflow): *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): # initialize the agent self.agent = AgentScopeReActAgent( diff --git a/docs/sphinx_doc/source_zh/main.md b/docs/sphinx_doc/source_zh/main.md index 3d1c673e25..9e490f846b 100644 --- a/docs/sphinx_doc/source_zh/main.md +++ b/docs/sphinx_doc/source_zh/main.md @@ -82,6 +82,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: | AsymRE [[论文](https://arxiv.org/pdf/2506.20520)] | [[GSM8K 例子](https://github.com/modelscope/Trinity-RFT/tree/main/examples/asymre_gsm8k)] | [[代码](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/asymre_advantage.py)] | `algorithm_type: asymre` | | CISPO [[论文](https://arxiv.org/pdf/2506.13585)] | - | [[代码](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` | | SAPO [[论文](https://arxiv.org/pdf/2511.20347)] | - | [[代码](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` | +| On-Policy Distillation [[博客](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[论文](https://arxiv.org/pdf/2306.13649)] | [[GSM8K 示例](https://github.com/modelscope/Trinity-RFT/tree/main/examples/on_policy_distill)] | [[代码](https://github.com/modelscope/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` | diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md index a3bb025cfc..3e8cbd3133 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md @@ -92,11 +92,12 @@ class Workflow(ABC): *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, # 主要用于 LLM-as-a-judge 场景,这里可以忽略 + auxiliary_models: Optional[List[ModelWrapper]] = None, # 主要用于 LLM-as-a-judge 场景,也可以用作 distillation 的 teacher ): self.task = task self.model = model - self.auxiliary_models = auxiliary_models + self.auxiliary_model_wrappers = auxiliary_models + self.auxiliary_models = ...
# 从 ModelWrapper 自动派生的 OpenAI client @abstractmethod def run(self) -> List[Experience]: @@ -109,7 +110,7 @@ class Workflow(ABC): - `task`({class}`trinity.common.workflows.Task`):数据集中的单个任务。 - `model`({class}`trinity.common.models.model.ModelWrapper`):正在训练的模型,提供类似于 OpenAI 的接口,能够接收对话消息列表并返回 LLM 生成的内容(包括回复文本 `response_text`、完整序列 token id `tokens`、prompt 部分 token 长度 `prompt_length`,以及输出 token 对数概率列表 `logprobs`)。 -- `auxiliary_models`(`List[openai.OpenAI]`):未参与训练的辅助模型列表。所有模型均通过兼容 OpenAI 的 API 提供,主要用于 LLM-as-a-judge 场景。 +- `auxiliary_models`(`List[ModelWrapper]`):辅助模型的 ModelWrapper 列表。可通过 `self.auxiliary_models` 访问 OpenAI client(根据 workflow 的 `is_async` 自动派生)。 以下是一个仅使用 `raw_task` 和 `rollout_args` 初始化简单工作流的示例。在更复杂的情况下,你可以使用 `format_args` 进行进一步自定义。 @@ -437,10 +438,10 @@ class MyWorkflow(Workflow): *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) - self.judge_model = self.auxiliary_models[0] # 使用第一个辅助模型作为评判者 + self.judge_model = self.auxiliary_models[0] # 从 ModelWrapper 自动派生的 OpenAI client def run(self) -> List[Experience]: response = self.do_something() diff --git a/docs/sphinx_doc/source_zh/tutorial/example_react.md b/docs/sphinx_doc/source_zh/tutorial/example_react.md index ab6b6a1d73..ec71d455e2 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_react.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_react.md @@ -88,7 +88,7 @@ class AgentScopeReActWorkflow(Workflow): *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): # initialize the agent self.agent = AgentScopeReActAgent( diff --git a/examples/learn_to_ask/workflow/workflow_learn2ask.py b/examples/learn_to_ask/workflow/workflow_learn2ask.py index 6e5cb6e0da..f8b4b49c43 100644 --- a/examples/learn_to_ask/workflow/workflow_learn2ask.py +++ b/examples/learn_to_ask/workflow/workflow_learn2ask.py @@ -7,8 +7,6 @@ import time from typing import List, Optional -import openai - from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.workflows.workflow import SimpleWorkflow, Task @@ -36,7 +34,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): self.train_mode = task.workflow_args.get("train_mode", "Ra+Rs") self.fusion_mode = task.workflow_args.get("fusion_mode", "default") diff --git a/examples/opd_gsm8k/README.md b/examples/opd_gsm8k/README.md new file mode 100644 index 0000000000..eb0f9d6dad --- /dev/null +++ b/examples/opd_gsm8k/README.md @@ -0,0 +1,36 @@ +# Example: On-Policy Distillation on GSM8K dataset + +This example demonstrates On-Policy Distillation (OPD) algorithm training on the GSM8K dataset. + +On-Policy Distillation is a knowledge distillation method, where in this example: +1. **Student model** (`Qwen/Qwen2.5-1.5B-Instruct`) generates trajectories with logprobs +2. **Teacher model** (`Qwen/Qwen2.5-Math-7B-Instruct`) computes logprobs on the same trajectories +3. The advantage is computed as: `advantages = kl_coef * (teacher_logprobs - student_logprobs)` +4. 
The student model is trained to minimize this KL divergence, effectively learning from the teacher + +## Key Configuration + +- **Algorithm**: `on_policy_distill` +- **Workflow**: `on_policy_distill_workflow` +- **Student Model**: `Qwen/Qwen2.5-1.5B-Instruct` +- **Teacher Model**: `Qwen/Qwen2.5-Math-7B-Instruct` (configured as auxiliary model) + +## Running the Example + +Download the model checkpoints and modify your config file, then run: +```bash +trinity run examples/opd_gsm8k/opd_gsm8k.yaml +``` + +Then you are all set! It should be pretty simple😄, and the training should converge very quickly. + + + +![](../../docs/sphinx_doc/assets/opd_acc.png) +![](../../docs/sphinx_doc/assets/opd_kl.png) + + +## References + +- https://arxiv.org/pdf/2306.13649 +- https://thinkingmachines.ai/blog/on-policy-distillation/ diff --git a/examples/opd_gsm8k/opd_gsm8k.yaml b/examples/opd_gsm8k/opd_gsm8k.yaml new file mode 100644 index 0000000000..2e75e8c232 --- /dev/null +++ b/examples/opd_gsm8k/opd_gsm8k.yaml @@ -0,0 +1,74 @@ +project: "Trinity-RFT-gsm8k-opd" +name: "qwen2.5-1.5B-distill-from-math-7B-lr1e-5" +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: on_policy_distill + repeat_times: 8 + optimizer: + lr: 1e-5 + advantage_fn_args: + kl_coef: 1.0 +model: + # Student model + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct} + max_response_tokens: 1024 + max_model_len: 2048 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k} + subset_name: main + split: train + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + # Use on_policy_distill_math_workflow for Qwen2.5-Math style format with accuracy reward + default_workflow_type: 'on_policy_distill_math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_opd_buffer + storage_type: queue +explorer: + eval_interval: 50 + runner_per_model: 8 + rollout_model: + # Student model for rollout + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + auxiliary_models: + # Teacher model for distillation + - model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-Math-7B-Instruct} + engine_num: 1 + tensor_parallel_size: 2 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + max_model_len: 4096 + max_prompt_tokens: 2048 + max_response_tokens: 1024 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + save_interval: 100 + grad_clip: 1.0 + use_dynamic_bsz: true + max_token_len_per_gpu: 16384 + ulysses_sequence_parallel_size: 1 +monitor: + monitor_type: wandb diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 16bb0b7bbb..71a87dea96 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -48,9 +48,10 @@ def __init__(self, model, task: Task, auxiliary_models=None): self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] self.repeat_times = task.rollout_args.n - if auxiliary_models is not None: - for model in auxiliary_models: - assert isinstance(model, openai.OpenAI) + # Check self.auxiliary_models (OpenAI clients derived from ModelWrapper) + if self.auxiliary_models is not None: + for m in self.auxiliary_models: + assert isinstance(m, openai.OpenAI) def reset(self, task: Task): self.obj =
task.raw_task @@ -92,9 +93,10 @@ def __init__(self, model, task: Task, auxiliary_models=None): self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] self.repeat_times = task.rollout_args.n - if auxiliary_models is not None: - for model in auxiliary_models: - assert isinstance(model, openai.AsyncOpenAI) + # Check self.auxiliary_models (AsyncOpenAI clients derived from ModelWrapper) + if self.auxiliary_models is not None: + for m in self.auxiliary_models: + assert isinstance(m, openai.AsyncOpenAI) def reset(self, task: Task): self.obj = task.raw_task diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index 3b347b26e5..979d08a779 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py @@ -27,6 +27,7 @@ "sppo": "trinity.algorithm.algorithm.sPPOAlgorithm", "rec": "trinity.algorithm.algorithm.RECAlgorithm", "multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm", + "on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm", }, ) diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 741a68fa6f..3b4dfe887e 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -17,6 +17,7 @@ "asymre": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREGroupAdvantage", "asymre_verl": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREAdvantageFn", "rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage", + "on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage", }, ) diff --git a/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py b/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py new file mode 100644 index 0000000000..d4265552bc --- /dev/null +++ b/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +"""On-Policy Distillation advantage computation. + +Reference: Tinker library's on-policy distillation. + +advantages = -(student_logprobs - teacher_logprobs) + = teacher_logprobs - student_logprobs +""" + +from typing import Dict, Tuple + +from verl import DataProto + +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn + + +class OnPolicyDistillAdvantage(AdvantageFn): + """Advantage function for on-policy distillation. + + Computes: advantages = kl_coef * (teacher_logprobs - student_logprobs) + + The teacher_logprobs should be stored in Experience.teacher_logprobs + by the workflow during exploration. + """ + + def __init__(self, kl_coef: float = 1.0) -> None: + self.kl_coef = kl_coef + + def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: + """Compute advantages from teacher and student logprobs. 
+ + Args: + exps: DataProto containing: + - old_log_probs: student's sampling logprobs [batch, seq] + - teacher_log_probs: teacher's logprobs [batch, seq] + - response_mask: mask for response tokens [batch, seq] + + Returns: + exps: DataProto with advantages and returns added + metrics: Dict with kl and advantage statistics + """ + metrics = {} + + old_log_probs = exps.batch["old_log_probs"] # student sampling logprobs + teacher_log_probs = exps.batch["teacher_log_probs"] + response_mask = exps.batch["response_mask"] + + # advantages = -(student - teacher) = teacher - student + advantages = self.kl_coef * (teacher_log_probs - old_log_probs) + + # Apply mask + advantages = advantages * response_mask + + exps.batch["advantages"] = advantages + exps.batch["returns"] = advantages.clone() + + # Metrics + kl_per_token = old_log_probs - teacher_log_probs + kl_sum = (kl_per_token * response_mask).sum(dim=-1) + metrics["kl/mean"] = kl_sum.mean().item() + metrics["kl/std"] = kl_sum.std().item() if kl_sum.numel() > 1 else 0.0 + metrics["advantages/mean"] = advantages.sum(dim=-1).mean().item() + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return {"kl_coef": 1.0} diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index b08da64cdc..cd8fce5f35 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -459,3 +459,35 @@ def default_config(cls) -> Dict: "kl_loss_fn": "k2", "entropy_loss_fn": "default", } + + +class OnPolicyDistillAlgorithm(AlgorithmType): + """On-Policy Distillation Algorithm. + + Reference: Tinker library. + + Workflow stores teacher_logprobs on each experience (Experience.teacher_logprobs). + Trainer's advantage_fn computes: advantages = teacher_logprobs - student_logprobs + Trainer uses: + importance_sampling loss if no clipping is needed + ppo loss if clipping is needed, for better stability + """ + + use_critic: bool = False + use_reference: bool = False + compute_advantage_in_trainer: bool = True # advantage_fn computes from teacher_logprobs + can_balance_batch: bool = True + schema: str = "experience" + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 8, + "advantage_fn": "on_policy_distill", + "advantage_fn_args": {"kl_coef": 1.0}, + "sample_strategy": "default", + "policy_loss_fn": "ppo", # or importance_sampling if no clipping is needed + "kl_penalty_fn": "none", + "kl_loss_fn": "none", + "entropy_loss_fn": "none", + } diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 5d30e9f4f6..6d1b0a2465 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -18,6 +18,7 @@ "sppo": "trinity.algorithm.policy_loss_fn.sppo_loss_fn.sPPOPolicyLossFn", "rec": "trinity.algorithm.policy_loss_fn.rec_policy_loss.RECPolicyLossFn", "sapo": "trinity.algorithm.policy_loss_fn.sapo_policy_loss.SAPOPolicyLossFn", + "importance_sampling": "trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss.ImportanceSamplingLossFn", }, ) diff --git a/trinity/algorithm/policy_loss_fn/importance_sampling_policy_loss.py b/trinity/algorithm/policy_loss_fn/importance_sampling_policy_loss.py new file mode 100644 index 0000000000..4f273b3780 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/importance_sampling_policy_loss.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +"""The simplest importance sampling policy loss.
+ +loss = -(prob_ratio * advantages).sum() +where prob_ratio = exp(current_logprobs - sampling_logprobs) + +Note: This loss is used for on-policy distillation. +""" + +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn +from trinity.algorithm.utils import aggregate_loss, masked_mean + + +class ImportanceSamplingLossFn(PolicyLossFn): + """Pure importance sampling loss without clipping. + + loss = -(ratio * advantages) + where ratio = exp(logprob - old_logprob) + """ + + def __init__( + self, + backend: str = "verl", + loss_agg_mode: str = "token-mean", + ) -> None: + super().__init__(backend=backend) + self.loss_agg_mode = loss_agg_mode + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + # prob_ratio = exp(current_logprobs - sampling_logprobs) + log_ratio = logprob - old_logprob + log_ratio = torch.clamp(log_ratio, min=-20.0, max=20.0) + ratio = torch.exp(log_ratio) + + # loss = -(prob_ratio * advantages) + pg_losses = -advantages * ratio + pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode) + + metrics = { + "pg_loss": pg_loss.detach().item(), + "ratio/mean": masked_mean(ratio, action_mask).detach().item(), + "approx_kl": masked_mean(-log_ratio, action_mask).detach().item(), + } + + return pg_loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return {"loss_agg_mode": "token-mean"} diff --git a/trinity/common/config.py b/trinity/common/config.py index 883e71752d..cf7a5a04d1 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -80,7 +80,7 @@ class FormatConfig: class GenerationConfig: temperature: Optional[float] = None # 1.0 top_p: Optional[float] = None # 1.0 - top_k: Optional[int] = None # -1 + top_k: int = -1 # -1 means disabled logprobs: Optional[int] = None # 0 # vLLM return `logprobs + 1` elements max_tokens: Optional[int] = None # if None, use model.max_response_tokens # repeat each task for `n` times diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 2a734144d2..0f94e0bdd5 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -133,6 +133,9 @@ class Experience: # for multi-modal data multi_modal_inputs: Optional[Dict[str, Tensor]] = None # Multi-modal inputs for verl trainer + # for on-policy distillation + teacher_logprobs: Optional[Tensor] = None # [resp_length] + def __init__( # noqa: C901 self, *, @@ -157,6 +160,7 @@ def __init__( # noqa: C901 chosen_messages=None, rejected_messages=None, multi_modal_inputs=None, + teacher_logprobs=None, ): if action_mask is not None: experience_type = "multi_turn" @@ -229,6 +233,11 @@ def __init__( # noqa: C901 else: self.multi_modal_inputs[key] = value + # Handle teacher_logprobs + if isinstance(teacher_logprobs, list): + teacher_logprobs = torch.tensor(teacher_logprobs, dtype=torch.float32) + self.teacher_logprobs = teacher_logprobs + if not isinstance(self.tokens, Tensor): self.tokens = torch.tensor(self.tokens) if self.logprobs is not None and not isinstance(self.logprobs, Tensor): @@ -239,6 +248,8 @@ def __init__( # noqa: C901 self.chosen = torch.tensor(self.chosen) if self.rejected is not None and not isinstance(self.rejected, Tensor): self.rejected = torch.tensor(self.rejected) + if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor): + self.teacher_logprobs = 
torch.tensor(self.teacher_logprobs, dtype=torch.float32) def serialize(self) -> bytes: """Serialize the experience to bytes.""" @@ -341,6 +352,14 @@ def gather( else: multi_modal_inputs = None + # gather teacher_logprobs + if all(exp.teacher_logprobs is not None for exp in experiences): + teacher_logprobs = gather_response_attrs( + experiences, "teacher_logprobs", max_response_length + ) + else: + teacher_logprobs = None + exps = Experiences( eids=eids, tokens=tokens, @@ -353,6 +372,7 @@ def gather( prompt_length=max_prompt_length, logprobs=logprobs, multi_modal_inputs=multi_modal_inputs, + teacher_logprobs=teacher_logprobs, ) if custom_fields is not None: for custom_field in custom_fields: @@ -442,6 +462,7 @@ class Experiences: custom_fields: List[str] = field( default_factory=list ) # Custom fields to include in the gathered experiences + teacher_logprobs: Optional[Tensor] = None # [batch_size, response_length] @property def batch_size(self) -> int: diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 4aeca67f5b..b8496f8f41 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -44,6 +44,9 @@ # others "simple_mm_workflow": "trinity.common.workflows.simple_mm_workflow.SimpleMMWorkflow", "async_simple_mm_workflow": "trinity.common.workflows.simple_mm_workflow.AsyncSimpleMMWorkflow", + # on-policy distillation workflows + "on_policy_distill_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillWorkflow", + "on_policy_distill_math_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillMathWorkflow", }, ) diff --git a/trinity/common/workflows/agentscope/react/react_workflow.py b/trinity/common/workflows/agentscope/react/react_workflow.py index f4c8c4375d..203af35623 100644 --- a/trinity/common/workflows/agentscope/react/react_workflow.py +++ b/trinity/common/workflows/agentscope/react/react_workflow.py @@ -5,8 +5,6 @@ from typing import Dict, List, Optional, Union -import openai - from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.workflows.workflow import Task, Workflow @@ -22,7 +20,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__( task=task, diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index 8fc28fbe32..0de1c41c2a 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -1,7 +1,5 @@ from typing import Awaitable, Callable, Dict, List, Optional -import openai - from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.workflows.workflow import Task, Workflow @@ -17,7 +15,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): """Initialize the adapter with the task and model.""" try: diff --git a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py index fab5403dcd..6b55bd588d 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py @@ -3,8 +3,6 @@ from 
typing import List, Optional -import openai - from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathBoxedRewardFn from trinity.common.workflows.workflow import Task, Workflow @@ -25,7 +23,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__( task=task, diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py index c8cc0dc155..90ca068e60 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py @@ -3,8 +3,6 @@ from typing import List, Optional -import openai - from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathBoxedRewardFn from trinity.common.workflows.workflow import Task, Workflow @@ -24,7 +22,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__( task=task, diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py index 8d2231b0e7..c24e30d0ce 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py @@ -5,8 +5,6 @@ import re from typing import List, Optional -import openai - from trinity.common.models.model import ModelWrapper from trinity.common.workflows.workflow import Task, Workflow @@ -24,7 +22,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): # get openai client from model super().__init__( task=task, diff --git a/trinity/common/workflows/envs/email_searcher/workflow.py b/trinity/common/workflows/envs/email_searcher/workflow.py index 0a243ef507..a644e32a3c 100644 --- a/trinity/common/workflows/envs/email_searcher/workflow.py +++ b/trinity/common/workflows/envs/email_searcher/workflow.py @@ -2,8 +2,6 @@ from typing import Dict, List, Optional -import openai - from trinity.common.models.model import ModelWrapper from trinity.common.workflows.envs.email_searcher.utils import ( AnswerModel, @@ -34,7 +32,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): # get openai client from model self.openai_client = model.get_openai_async_client() diff --git a/trinity/common/workflows/eval_workflow.py b/trinity/common/workflows/eval_workflow.py index 0fcb115b6f..a0e3e82d5c 100644 --- a/trinity/common/workflows/eval_workflow.py +++ b/trinity/common/workflows/eval_workflow.py @@ -4,8 +4,6 @@ from dataclasses import asdict from typing import List, Optional -import openai - from trinity.common.config import GenerationConfig from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper @@ -27,7 +25,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__( task=task, diff --git 
a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index a5213e0034..13e2fd4cb2 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -3,8 +3,6 @@ from typing import List, Optional -import openai - from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.workflows.workflow import SimpleWorkflow, Task @@ -18,7 +16,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): self.reset(task) super().__init__( diff --git a/trinity/common/workflows/math_ruler_workflow.py b/trinity/common/workflows/math_ruler_workflow.py index 42848dd0d7..11a7a3a8c6 100644 --- a/trinity/common/workflows/math_ruler_workflow.py +++ b/trinity/common/workflows/math_ruler_workflow.py @@ -3,8 +3,6 @@ import ast from typing import Any, List, Optional, Tuple -import openai - from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathRewardFn @@ -23,7 +21,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__( task=task, diff --git a/trinity/common/workflows/math_trainable_ruler_workflow.py b/trinity/common/workflows/math_trainable_ruler_workflow.py index d43cbe4aed..5c01ec499b 100644 --- a/trinity/common/workflows/math_trainable_ruler_workflow.py +++ b/trinity/common/workflows/math_trainable_ruler_workflow.py @@ -5,7 +5,6 @@ from typing import Any, List, Optional, Tuple import numpy as np -import openai from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper @@ -27,7 +26,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__( task=task, diff --git a/trinity/common/workflows/on_policy_distill_workflow.py b/trinity/common/workflows/on_policy_distill_workflow.py new file mode 100644 index 0000000000..e925103374 --- /dev/null +++ b/trinity/common/workflows/on_policy_distill_workflow.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +"""On-Policy Distillation Workflow. + +Reference: Tinker library's on-policy distillation implementation. + +Algorithm: +1. Student samples trajectories (with logprobs) +2. Teacher computes logprobs on same trajectories +3. Store teacher_logprobs on each experience (Experience.teacher_logprobs) +4. Trainer's advantage_fn computes: advantages = teacher_logprobs - student_logprobs +5. Train with importance_sampling loss +""" + +from dataclasses import asdict +from typing import List, Optional + +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.rewards.qwen25_eval import verify_math_answer +from trinity.common.workflows.workflow import Task, Workflow + + +class OnPolicyDistillWorkflow(Workflow): + """On-policy distillation workflow. + + Computes and stores teacher_logprobs on each experience (Experience.teacher_logprobs).
+ The advantage_fn in trainer will compute: + advantages = teacher_logprobs - student_logprobs + + Note: This workflow does NOT use reward_fn because: + - Advantage is computed from teacher-student logprobs difference + - No external reward signal is needed + """ + + is_async: bool = True + can_reset: bool = True + can_repeat: bool = True + + def __init__( + self, + *, + task: Task, + model: ModelWrapper, + auxiliary_models: Optional[List[ModelWrapper]] = None, + ): + super().__init__( + task=task, + model=model, + auxiliary_models=auxiliary_models, + ) + self.reset(task) + + assert ( + self.auxiliary_model_wrappers is not None and len(self.auxiliary_model_wrappers) >= 1 + ), "On-policy distillation requires at least one auxiliary model as teacher." + self.teacher_model = self.auxiliary_model_wrappers[0] + + self.temperature = task.workflow_args.get("temperature", 1.0) + + def reset(self, task: Task): + """Reset the workflow with a new task. + + Unlike BaseSimpleWorkflow, this does NOT require reward_fn. + """ + self.task = task + self.format_args = task.format_args + self.system_prompt = task.format_args.system_prompt + self.reply_prefix = task.format_args.reply_prefix + self.raw_task = task.raw_task + self.task_desc = task.task_desc + self.truth = task.truth + + def set_repeat_times(self, repeat_times, run_id_base): + self.repeat_times = repeat_times + self.task.rollout_args.n = repeat_times + self.run_id_base = run_id_base + + @property + def rollout_args(self): + return asdict(self.task.rollout_args) + + def format_messages(self): + """Format messages for the instruct model. + + Default format: system_prompt (optional) + task_desc + reply_prefix (optional) + """ + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + messages.append({"role": "user", "content": self.task_desc}) + if self.reply_prefix: + messages.append({"role": "assistant", "content": self.reply_prefix}) + return messages + + def compute_reward(self, response: Experience) -> float: + """Compute reward for a response. + + In base class, returns 0.0 as advantage is computed from teacher-student logprobs. + Subclasses can override this to compute actual rewards. + """ + return 0.0 + + async def run_async(self) -> List[Experience]: + messages = self.format_messages() + + # Step 1: Student samples trajectories + responses = await self.model.chat_async(messages, **self.rollout_args) + + for i, response in enumerate(responses): + # Step 2: Teacher computes logprobs + teacher_logprobs = await self.teacher_model.logprobs_async( + tokens=response.tokens.tolist(), + temperature=self.temperature, + ) + + # Extract response portion + resp_start = response.prompt_length - 1 + teacher_resp_logprobs = teacher_logprobs[resp_start:] + student_resp_logprobs = response.logprobs + + # Verify lengths match (they should be equal for the same token sequence) + assert len(teacher_resp_logprobs) == len(student_resp_logprobs), ( + f"Length mismatch: teacher_logprobs={len(teacher_resp_logprobs)}, " + f"student_logprobs={len(student_resp_logprobs)}. 
" + f"tokens={len(response.tokens)}, prompt_length={response.prompt_length}" + ) + + # Step 3: Store teacher_logprobs for advantage_fn + response.teacher_logprobs = teacher_resp_logprobs + + # Initialize metrics + if response.metrics is None: + response.metrics = {} + + # Compute reward (subclasses can override compute_reward) + response.reward = self.compute_reward(response) + + response.eid.run = i + self.run_id_base + + # KL divergence for monitoring + kl = (student_resp_logprobs - teacher_resp_logprobs).sum().item() + response.metrics["kl_divergence"] = kl + + return responses + + +class OnPolicyDistillMathWorkflow(OnPolicyDistillWorkflow): + """On-policy distillation workflow with Qwen2.5-Math style format. + + This workflow: + - Uses Qwen2.5-Math style prompt format (same as math_eval_workflow) + - Computes accuracy using verify_math_answer as reward + - Suitable for math reasoning tasks like GSM8K, MATH, etc. + """ + + def format_messages(self): + """Format messages using Qwen2.5-Math style. + + System prompt: "You are a helpful assistant." + User prompt: "{question}\nPlease reason step by step, and put your final answer within \\boxed{}." + """ + system_prompt = "You are a helpful assistant." + user_prompt = f"{self.task_desc}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + return [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + def compute_reward(self, response: Experience) -> float: + """Compute accuracy as reward using Qwen2.5-Math evaluation. + + Returns 1.0 if answer is correct, 0.0 otherwise. + """ + if response.response_text and self.truth: + accuracy, _ = verify_math_answer( + response_text=response.response_text, ground_truth=self.truth + ) + # Store accuracy in metrics + if response.metrics is None: + response.metrics = {} + response.metrics["accuracy"] = accuracy + return float(accuracy) + return 0.0 diff --git a/trinity/common/workflows/rubric_judge_workflow.py b/trinity/common/workflows/rubric_judge_workflow.py index 2311803cac..455ac0603a 100644 --- a/trinity/common/workflows/rubric_judge_workflow.py +++ b/trinity/common/workflows/rubric_judge_workflow.py @@ -21,7 +21,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__( task=task, diff --git a/trinity/common/workflows/simple_mm_workflow.py b/trinity/common/workflows/simple_mm_workflow.py index 97044e0cf1..2eca2274fb 100644 --- a/trinity/common/workflows/simple_mm_workflow.py +++ b/trinity/common/workflows/simple_mm_workflow.py @@ -1,7 +1,5 @@ from typing import List, Optional -import openai - from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.reward_fn import RewardFn @@ -16,7 +14,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): self.reset(task) super().__init__( diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 7d9e16e689..cf3b4d449b 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import asdict, dataclass, field -from typing import TYPE_CHECKING, Any, List, Optional, Type, Union +from typing import TYPE_CHECKING, List, Optional, Type, Union from 
trinity.common.config import FormatConfig, GenerationConfig from trinity.common.experience import Experience @@ -38,16 +38,17 @@ class Task(dict): index: dict = field(default_factory=dict) def to_workflow( - self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None + self, + model: ModelWrapper, + auxiliary_models: Optional[List[ModelWrapper]] = None, ) -> Workflow: """Convert the task to a workflow. Args: model (ModelWrapper): The rollout model for the workflow. - auxiliary_models (List[openai.OpenAI]): The auxiliary models for the workflow. - - Note: - `model_path` attribute is added to the `auxiliary_models` for use within the workflow. + auxiliary_models (List[ModelWrapper]): The auxiliary model wrappers. + Workflows can access both the ModelWrapper and OpenAI client via + self.auxiliary_model_wrappers and self.auxiliary_models respectively. Returns: Workflow: The generated workflow object. @@ -78,6 +79,10 @@ class Workflow: """The base workflow class. A workflow is a runnable object which generates a list of experiences. + + Attributes: + auxiliary_model_wrappers: List of ModelWrapper instances for auxiliary models. + auxiliary_models: List of OpenAI clients (sync or async based on is_async) for auxiliary models. """ can_reset: bool = False # whether the workflow can be reset with a new task. If true, `reset()` must be implemented. @@ -89,11 +94,19 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): self.task = task self.model = model - self.auxiliary_models = auxiliary_models + # Store ModelWrapper instances + self.auxiliary_model_wrappers = auxiliary_models + # Get OpenAI clients from ModelWrapper (async or sync based on workflow type) + self.auxiliary_models: Optional[Union[List[openai.OpenAI], List[openai.AsyncOpenAI]]] = None + if auxiliary_models: + if self.__class__.is_async: + self.auxiliary_models = [m.get_openai_async_client() for m in auxiliary_models] + else: + self.auxiliary_models = [m.get_openai_client() for m in auxiliary_models] self.run_id_base = 0 self.logger = get_logger(__name__) @@ -151,7 +164,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): super().__init__( task=task, @@ -203,7 +216,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): self.reset(task) super().__init__( @@ -315,7 +328,7 @@ def __init__( *, task: Task, model: ModelWrapper, - auxiliary_models: Optional[List[openai.OpenAI]] = None, + auxiliary_models: Optional[List[ModelWrapper]] = None, ): self.reset(task) super().__init__( diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 57c5c25bf0..9edba9014f 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -72,8 +72,6 @@ def __init__( ) for model in self.auxiliary_models ] - self.auxiliary_model_clients = [] - self.auxiliary_model_async_clients = [] self.workflow_instance: Workflow = None self.runner_id = runner_id self.runner_state = { @@ -89,11 +87,6 @@ async def prepare(self) -> None: self.model_wrapper.prepare(), *(aux_model.prepare() for aux_model in self.auxiliary_model_wrappers), ) - for model in self.auxiliary_model_wrappers: - api_client = model.get_openai_client() - async_api_client = 
model.get_openai_async_client() - self.auxiliary_model_clients.append(api_client) - self.auxiliary_model_async_clients.append(async_api_client) def is_alive(self): return True @@ -106,13 +99,10 @@ def _create_workflow_instance(self, task: Task) -> None: or not self.workflow_instance.__class__ == task.workflow or not self.workflow_instance.resettable ): + # Pass ModelWrapper directly; Workflow.__init__ will get OpenAI clients automatically self.workflow_instance = task.to_workflow( self.model_wrapper, - ( - self.auxiliary_model_async_clients - if task.workflow.is_async - else self.auxiliary_model_clients - ), + self.auxiliary_model_wrappers, ) else: self.workflow_instance.reset(task) diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py index ab50b7c877..922182e643 100644 --- a/trinity/trainer/verl/utils.py +++ b/trinity/trainer/verl/utils.py @@ -58,6 +58,8 @@ def to_data_proto(experiences: Experiences, logger: Logger) -> DataProto: # noq batch_dict["advantages"] = experiences.advantages if experiences.returns is not None: batch_dict["returns"] = experiences.returns + if experiences.teacher_logprobs is not None: + batch_dict["teacher_log_probs"] = experiences.teacher_logprobs if experiences.multi_modal_inputs is not None: batch_size = len(batch_dict["unique_ids"])
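
To make the relationship between the new `OnPolicyDistillAdvantage` and `ImportanceSamplingLossFn` concrete, here is a minimal, self-contained numerical sketch in plain PyTorch. It does not call the Trinity classes; the tensor values are made up, and the manual mean mirrors the `token-mean` aggregation. It checks that with perfectly on-policy data (current logprobs equal sampling logprobs, so the ratio is 1) the loss reduces to `kl_coef` times the mean per-token log-ratio between student and teacher, i.e. minimizing it pushes the student toward the teacher, which is also what the `kl` metrics track.

```python
import torch

# Toy tensors: one response of 4 tokens (values are illustrative).
student_logprobs = torch.tensor([[-0.5, -1.2, -0.3, -2.0]])  # old_log_probs from rollout
teacher_logprobs = torch.tensor([[-0.4, -0.6, -0.3, -1.0]])  # teacher_log_probs
response_mask = torch.ones_like(student_logprobs)
kl_coef = 1.0

# OnPolicyDistillAdvantage: per-token advantage is the (scaled, masked) teacher-student gap.
advantages = kl_coef * (teacher_logprobs - student_logprobs) * response_mask

# ImportanceSamplingLossFn with perfectly on-policy data: the current policy equals
# the sampling policy, so the probability ratio is exactly 1 for every token.
logprob = student_logprobs.clone()
ratio = torch.exp(logprob - student_logprobs)  # all ones
pg_losses = -advantages * ratio
loss = (pg_losses * response_mask).sum() / response_mask.sum()  # token-mean aggregation

# In this regime the loss equals kl_coef times the mean per-token log-ratio
# (student - teacher), a single-sample estimate of the reverse KL.
log_ratio_mean = ((student_logprobs - teacher_logprobs) * response_mask).sum() / response_mask.sum()
assert torch.allclose(loss, kl_coef * log_ratio_mean)
print(loss.item(), log_ratio_mean.item())  # both -0.425 with these toy values
```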
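
The switch of `auxiliary_models` from `List[openai.OpenAI]` to `List[ModelWrapper]` means a workflow now has two handles per auxiliary model: the wrapper itself in `self.auxiliary_model_wrappers` and an auto-derived OpenAI client in `self.auxiliary_models` (async or sync depending on `is_async`). The sketch below illustrates this under stated assumptions: the class, the judge prompt, the reward rule, and the `"judge-model"` name are hypothetical and not part of this PR; only the attributes and methods that appear in the diff (`chat_async`, `auxiliary_model_wrappers`, `auxiliary_models`, `task_desc`, `rollout_args.n`, `response_text`, `reward`) are relied on.

```python
from typing import List, Optional

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import Task, Workflow


class JudgedMathWorkflow(Workflow):
    """Hypothetical async workflow showing both auxiliary-model handles."""

    is_async: bool = True  # so self.auxiliary_models holds openai.AsyncOpenAI clients

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[ModelWrapper]] = None,
    ):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        # Raw ModelWrapper handle, e.g. for logprob queries against the auxiliary engine
        self.judge_wrapper = self.auxiliary_model_wrappers[0]
        # OpenAI-compatible client derived by Workflow.__init__ from the same wrapper
        self.judge_client = self.auxiliary_models[0]

    async def run_async(self) -> List[Experience]:
        responses = await self.model.chat_async(
            [{"role": "user", "content": self.task.task_desc}],
            n=self.task.rollout_args.n,
        )
        for response in responses:
            # Score each rollout via the standard OpenAI chat API; "judge-model"
            # is a placeholder name, the real value depends on the deployment.
            verdict = await self.judge_client.chat.completions.create(
                model="judge-model",
                messages=[
                    {
                        "role": "user",
                        "content": f"Answer YES or NO: is this correct?\n{response.response_text}",
                    },
                ],
            )
            response.reward = 1.0 if "YES" in verdict.choices[0].message.content else 0.0
        return responses
```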